ggerganov committed
Commit 511930c · unverified · 1 Parent(s): 4b6e041

talk-llama : sync llama.cpp (#3084)

Files changed (36)
  1. examples/talk-llama/CMakeLists.txt +3 -0
  2. examples/talk-llama/llama-adapter.cpp +55 -20
  3. examples/talk-llama/llama-adapter.h +11 -9
  4. examples/talk-llama/llama-arch.cpp +264 -33
  5. examples/talk-llama/llama-arch.h +33 -0
  6. examples/talk-llama/llama-batch.h +2 -2
  7. examples/talk-llama/llama-chat.cpp +76 -2
  8. examples/talk-llama/llama-chat.h +4 -0
  9. examples/talk-llama/llama-context.cpp +0 -0
  10. examples/talk-llama/llama-context.h +214 -77
  11. examples/talk-llama/llama-cparams.h +1 -0
  12. examples/talk-llama/llama-grammar.cpp +183 -183
  13. examples/talk-llama/llama-grammar.h +13 -4
  14. examples/talk-llama/llama-graph.cpp +1706 -0
  15. examples/talk-llama/llama-graph.h +596 -0
  16. examples/talk-llama/llama-hparams.cpp +8 -0
  17. examples/talk-llama/llama-hparams.h +21 -0
  18. examples/talk-llama/llama-impl.h +6 -6
  19. examples/talk-llama/llama-io.cpp +15 -0
  20. examples/talk-llama/llama-io.h +35 -0
  21. examples/talk-llama/llama-kv-cache.cpp +965 -303
  22. examples/talk-llama/llama-kv-cache.h +145 -150
  23. examples/talk-llama/llama-memory.cpp +1 -0
  24. examples/talk-llama/llama-memory.h +21 -0
  25. examples/talk-llama/llama-mmap.cpp +11 -1
  26. examples/talk-llama/llama-mmap.h +1 -0
  27. examples/talk-llama/llama-model-loader.cpp +10 -5
  28. examples/talk-llama/llama-model-loader.h +5 -3
  29. examples/talk-llama/llama-model.cpp +0 -0
  30. examples/talk-llama/llama-model.h +42 -1
  31. examples/talk-llama/llama-quant.cpp +39 -9
  32. examples/talk-llama/llama-sampling.cpp +179 -67
  33. examples/talk-llama/llama-vocab.cpp +55 -5
  34. examples/talk-llama/llama.cpp +0 -0
  35. examples/talk-llama/llama.h +147 -47
  36. examples/talk-llama/unicode.cpp +9 -2
examples/talk-llama/CMakeLists.txt CHANGED
@@ -12,9 +12,12 @@ if (WHISPER_SDL2)
         llama-context.cpp
         llama-cparams.cpp
         llama-grammar.cpp
+        llama-graph.cpp
         llama-hparams.cpp
         llama-impl.cpp
+        llama-io.cpp
         llama-kv-cache.cpp
+        llama-memory.cpp
         llama-mmap.cpp
         llama-model-loader.cpp
         llama-model.cpp
examples/talk-llama/llama-adapter.cpp CHANGED
@@ -4,14 +4,13 @@
 #include "llama-mmap.h"
 #include "llama-model.h"
 
-#include <algorithm>
 #include <map>
 #include <cassert>
 #include <stdexcept>
 
 // vec
 
-struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
+ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
     if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
         return nullptr;
     }
@@ -19,7 +18,7 @@ struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
     return tensors[il];
 }
 
-struct ggml_tensor * llama_adapter_cvec::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const {
+ggml_tensor * llama_adapter_cvec::apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const {
     ggml_tensor * layer_dir = tensor_for(il);
     if (layer_dir != nullptr) {
         cur = ggml_add(ctx, cur, layer_dir);
@@ -40,7 +39,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
-            struct ggml_init_params params = {
+            ggml_init_params params = {
                 /*.mem_size =*/ hparams.n_layer*ggml_tensor_overhead(),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc =*/ true,
@@ -91,7 +90,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
     return true;
 }
 
-int32_t llama_adapter_cvec::apply(
+bool llama_adapter_cvec::apply(
         const llama_model & model,
         const float * data,
         size_t len,
@@ -104,17 +103,17 @@ int32_t llama_adapter_cvec::apply(
         // disable the current control vector (but leave allocated for later)
         layer_start = -1;
         layer_end = -1;
-        return 0;
+        return true;
     }
 
     if (n_embd != (int) hparams.n_embd) {
         LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
-        return 1;
+        return false;
     }
 
     if (tensors.empty()) {
         if (!init(model)) {
-            return 1;
+            return false;
         }
     }
 
@@ -130,12 +129,12 @@ int32_t llama_adapter_cvec::apply(
         }
     }
 
-    return 0;
+    return true;
 }
 
 // lora
 
-llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor * w) {
+llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) {
     const std::string name(w->name);
 
     const auto pos = ab_map.find(name);
@@ -146,11 +145,11 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor *
     return nullptr;
 }
 
-static void llama_adapter_lora_init_impl(struct llama_model & model, const char * path_lora, struct llama_adapter_lora & adapter) {
+static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) {
     LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
 
     ggml_context * ctx_init;
-    struct gguf_init_params meta_gguf_params = {
+    gguf_init_params meta_gguf_params = {
         /* .no_alloc = */ true,
         /* .ctx = */ &ctx_init,
     };
@@ -201,7 +200,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             // add a new context
-            struct ggml_init_params params = {
+            ggml_init_params params = {
                 /*.mem_size =*/ n_tensors*ggml_tensor_overhead(),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc =*/ true,
@@ -248,6 +247,26 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
         }
     }
 
+    // get extra buffer types of the CPU
+    // TODO: a more general solution for non-CPU extra buft should be imlpemented in the future
+    //       ref: https://github.com/ggml-org/llama.cpp/pull/12593#pullrequestreview-2718659948
+    std::vector<ggml_backend_buffer_type_t> buft_extra;
+    {
+        auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+        auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
+
+        auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
+            ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
+
+        if (ggml_backend_dev_get_extra_bufts_fn) {
+            ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
+            while (extra_bufts && *extra_bufts) {
+                buft_extra.emplace_back(*extra_bufts);
+                ++extra_bufts;
+            }
+        }
+    }
+
     // add tensors
     for (auto & it : ab_map) {
         const std::string & name = it.first;
@@ -264,7 +283,23 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
             throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)");
        }
 
-        struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
+        auto * buft = ggml_backend_buffer_get_type(model_tensor->buffer);
+
+        // do not load loras to extra buffer types (i.e. bufts for repacking) -> use the CPU in that case
+        for (auto & ex : buft_extra) {
+            if (ex == buft) {
+                LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
+
+                auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+                buft = ggml_backend_dev_buffer_type(cpu_dev);
+
+                break;
+            }
+        }
+
+        LLAMA_LOG_DEBUG("%s: lora for '%s' -> '%s'\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
+
+        ggml_context * dev_ctx = ctx_for_buft(buft);
         // validate tensor shape
         if (is_token_embd) {
             // expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd()
@@ -281,8 +316,8 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
         }
 
         // save tensor to adapter
-        struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
-        struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
+        ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
+        ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
         ggml_set_name(tensor_a, w.a->name);
         ggml_set_name(tensor_b, w.b->name);
         adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b);
@@ -308,7 +343,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
     {
         llama_file gguf_file(path_lora, "rb");
         std::vector<uint8_t> read_buf;
-        auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) {
+        auto set_tensor = [&](ggml_tensor * orig, ggml_tensor * dev) {
             size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name));
             size_t size = ggml_nbytes(orig);
             read_buf.resize(size);
@@ -327,8 +362,8 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
     LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
 }
 
-struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model, const char * path_lora) {
-    struct llama_adapter_lora * adapter = new llama_adapter_lora();
+llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) {
+    llama_adapter_lora * adapter = new llama_adapter_lora();
 
     try {
         llama_adapter_lora_init_impl(*model, path_lora, *adapter);
@@ -342,6 +377,6 @@ struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model,
         return nullptr;
     }
 
-void llama_adapter_lora_free(struct llama_adapter_lora * adapter) {
+void llama_adapter_lora_free(llama_adapter_lora * adapter) {
     delete adapter;
 }
examples/talk-llama/llama-adapter.h CHANGED
@@ -15,11 +15,11 @@
 //
 
 struct llama_adapter_cvec {
-    struct ggml_tensor * tensor_for(int il) const;
+    ggml_tensor * tensor_for(int il) const;
 
-    struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const;
+    ggml_tensor * apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const;
 
-    int32_t apply(
+    bool apply(
             const llama_model & model,
             const float * data,
             size_t len,
@@ -36,7 +36,7 @@ private:
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
-    std::vector<struct ggml_tensor *> tensors; // per layer
+    std::vector<ggml_tensor *> tensors; // per layer
 };
 
 //
@@ -44,8 +44,8 @@ private:
 //
 
 struct llama_adapter_lora_weight {
-    struct ggml_tensor * a = nullptr;
-    struct ggml_tensor * b = nullptr;
+    ggml_tensor * a = nullptr;
+    ggml_tensor * b = nullptr;
 
     // get actual scale based on rank and alpha
     float get_scale(float alpha, float adapter_scale) const {
@@ -55,12 +55,12 @@ struct llama_adapter_lora_weight {
     }
 
     llama_adapter_lora_weight() = default;
-    llama_adapter_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
+    llama_adapter_lora_weight(ggml_tensor * a, ggml_tensor * b) : a(a), b(b) {}
 };
 
 struct llama_adapter_lora {
     // map tensor name to lora_a_b
-    std::unordered_map<std::string, struct llama_adapter_lora_weight> ab_map;
+    std::unordered_map<std::string, llama_adapter_lora_weight> ab_map;
 
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
@@ -70,5 +70,7 @@ struct llama_adapter_lora {
     llama_adapter_lora() = default;
     ~llama_adapter_lora() = default;
 
-    llama_adapter_lora_weight * get_weight(struct ggml_tensor * w);
+    llama_adapter_lora_weight * get_weight(ggml_tensor * w);
 };
+
+using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>;
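For orientation, a minimal sketch of how the LoRA adapter API touched by this header is typically driven from application code. The `llama_adapter_lora_init`/`llama_adapter_lora_free` calls appear in this diff; the surrounding model/context setup and the `llama_set_adapter_lora` call are assumptions based on the public llama.h of the same sync and are not part of this commit.

// Hedged sketch: load a LoRA adapter and attach it to an existing context.
// Assumes llama_model/llama_context were created elsewhere and that
// llama_set_adapter_lora() exists in the synced llama.h (not shown in this diff).
#include "llama.h"

static bool attach_lora(llama_model * model, llama_context * ctx, const char * path_lora) {
    llama_adapter_lora * adapter = llama_adapter_lora_init(model, path_lora);
    if (adapter == nullptr) {
        return false; // loading failed (wrong base model, bad file, ...)
    }

    // scale 1.0f applies the adapter at full strength
    llama_set_adapter_lora(ctx, adapter, 1.0f);

    // the adapter must stay alive while the context uses it;
    // call llama_adapter_lora_free(adapter) once it is no longer needed
    return true;
}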
examples/talk-llama/llama-arch.cpp CHANGED
@@ -6,6 +6,7 @@
 
 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LLAMA, "llama" },
+    { LLM_ARCH_LLAMA4, "llama4" },
     { LLM_ARCH_DECI, "deci" },
     { LLM_ARCH_FALCON, "falcon" },
     { LLM_ARCH_GROK, "grok" },
@@ -25,6 +26,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_QWEN2, "qwen2" },
     { LLM_ARCH_QWEN2MOE, "qwen2moe" },
     { LLM_ARCH_QWEN2VL, "qwen2vl" },
+    { LLM_ARCH_QWEN3, "qwen3" },
+    { LLM_ARCH_QWEN3MOE, "qwen3moe" },
     { LLM_ARCH_PHI2, "phi2" },
     { LLM_ARCH_PHI3, "phi3" },
     { LLM_ARCH_PHIMOE, "phimoe" },
@@ -36,6 +39,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_MINICPM3, "minicpm3" },
     { LLM_ARCH_GEMMA, "gemma" },
     { LLM_ARCH_GEMMA2, "gemma2" },
+    { LLM_ARCH_GEMMA3, "gemma3" },
     { LLM_ARCH_STARCODER2, "starcoder2" },
     { LLM_ARCH_MAMBA, "mamba" },
     { LLM_ARCH_XVERSE, "xverse" },
@@ -50,6 +54,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DEEPSEEK, "deepseek" },
     { LLM_ARCH_DEEPSEEK2, "deepseek2" },
     { LLM_ARCH_CHATGLM, "chatglm" },
+    { LLM_ARCH_GLM4, "glm4" },
     { LLM_ARCH_BITNET, "bitnet" },
     { LLM_ARCH_T5, "t5" },
     { LLM_ARCH_T5ENCODER, "t5encoder" },
@@ -58,10 +63,14 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_EXAONE, "exaone" },
     { LLM_ARCH_RWKV6, "rwkv6" },
     { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
+    { LLM_ARCH_RWKV7, "rwkv7" },
+    { LLM_ARCH_ARWKV7, "arwkv7" },
     { LLM_ARCH_GRANITE, "granite" },
     { LLM_ARCH_GRANITE_MOE, "granitemoe" },
     { LLM_ARCH_CHAMELEON, "chameleon" },
     { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
+    { LLM_ARCH_PLM, "plm" },
+    { LLM_ARCH_BAILINGMOE, "bailingmoe" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
 
@@ -70,6 +79,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },
     { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" },
     { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" },
+    { LLM_KV_GENERAL_FILE_TYPE, "general.file_type" },
     { LLM_KV_GENERAL_NAME, "general.name" },
     { LLM_KV_GENERAL_AUTHOR, "general.author" },
     { LLM_KV_GENERAL_VERSION, "general.version" },
@@ -108,23 +118,30 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
     { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
     { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
+    { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },
 
-    { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
-    { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
-    { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
-    { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" },
-    { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" },
-    { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
-    { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
-    { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
-    { LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
-    { LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
-    { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
-    { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
-    { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
-    { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
-    { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
-    { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
+    { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
+    { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
+    { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
+    { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" },
+    { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" },
+    { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
+    { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
+    { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
+    { LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
+    { LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
+    { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
+    { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
+    { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
+    { LLM_KV_ATTENTION_DECAY_LORA_RANK, "%s.attention.decay_lora_rank" },
+    { LLM_KV_ATTENTION_ICLR_LORA_RANK, "%s.attention.iclr_lora_rank" },
+    { LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, "%s.attention.value_residual_mix_lora_rank" },
+    { LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" },
+    { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
+    { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
+    { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
+    { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
+    { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
 
     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@@ -223,6 +240,35 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_LLAMA4,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+        },
+    },
     {
         LLM_ARCH_DECI,
         {
@@ -554,6 +600,45 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
         },
     },
+    {
+        LLM_ARCH_QWEN3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
+    {
+        LLM_ARCH_QWEN3MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_PHI2,
         {
@@ -766,6 +851,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
         },
     },
+    {
+        LLM_ARCH_GEMMA3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+        },
+    },
     {
         LLM_ARCH_STARCODER2,
         {
@@ -999,6 +1105,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
             { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
             { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
+            { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
+            { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
             { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
             { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
             { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
@@ -1015,6 +1123,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
         },
     },
+    {
+        LLM_ARCH_PLM,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
+            { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
+            { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_CHATGLM,
         {
@@ -1033,6 +1157,25 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
         },
     },
+    {
+        LLM_ARCH_GLM4,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+        },
+    },
     {
         LLM_ARCH_BITNET,
         {
@@ -1217,6 +1360,74 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
    },
+    {
+        LLM_ARCH_RWKV7,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
+            { LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
+            { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
+            { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
+            { LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
+            { LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
+            { LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
+            { LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
+            { LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
+            { LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
+            { LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
+            { LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
+            { LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
+            { LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
+            { LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
+            { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
+            { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
+            { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
+            { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
+            { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
+            { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
+            { LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" },
+            { LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" },
+            { LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" },
+        },
+    },
+    {
+        LLM_ARCH_ARWKV7,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
+            { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
+            { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
+            { LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
+            { LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
+            { LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
+            { LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
+            { LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
+            { LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
+            { LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
+            { LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
+            { LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
+            { LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
+            { LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
+            { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
+            { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
+            { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
+            { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
+            { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
+            { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_GRANITE,
         {
@@ -1296,6 +1507,29 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
         },
     },
+    {
+        LLM_ARCH_BAILINGMOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -1333,23 +1567,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@@ -1376,6 +1595,12 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_A2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_V1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_V2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_G1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_G2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@@ -1394,6 +1619,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_TIME_MIX_K_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_TIME_MIX_K_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_TIME_MIX_R_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
@@ -1401,6 +1629,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_W0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_A0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_V0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
     {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
examples/talk-llama/llama-arch.h CHANGED
@@ -10,6 +10,7 @@
 
 enum llm_arch {
     LLM_ARCH_LLAMA,
+    LLM_ARCH_LLAMA4,
     LLM_ARCH_DECI,
     LLM_ARCH_FALCON,
     LLM_ARCH_BAICHUAN,
@@ -29,6 +30,8 @@ enum llm_arch {
     LLM_ARCH_QWEN2,
     LLM_ARCH_QWEN2MOE,
     LLM_ARCH_QWEN2VL,
+    LLM_ARCH_QWEN3,
+    LLM_ARCH_QWEN3MOE,
     LLM_ARCH_PHI2,
     LLM_ARCH_PHI3,
     LLM_ARCH_PHIMOE,
@@ -40,6 +43,7 @@ enum llm_arch {
     LLM_ARCH_MINICPM3,
     LLM_ARCH_GEMMA,
     LLM_ARCH_GEMMA2,
+    LLM_ARCH_GEMMA3,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_XVERSE,
@@ -54,6 +58,7 @@ enum llm_arch {
     LLM_ARCH_DEEPSEEK,
     LLM_ARCH_DEEPSEEK2,
     LLM_ARCH_CHATGLM,
+    LLM_ARCH_GLM4,
     LLM_ARCH_BITNET,
     LLM_ARCH_T5,
     LLM_ARCH_T5ENCODER,
@@ -62,10 +67,14 @@ enum llm_arch {
     LLM_ARCH_EXAONE,
     LLM_ARCH_RWKV6,
     LLM_ARCH_RWKV6QWEN2,
+    LLM_ARCH_RWKV7,
+    LLM_ARCH_ARWKV7,
     LLM_ARCH_GRANITE,
     LLM_ARCH_GRANITE_MOE,
     LLM_ARCH_CHAMELEON,
     LLM_ARCH_WAVTOKENIZER_DEC,
+    LLM_ARCH_PLM,
+    LLM_ARCH_BAILINGMOE,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -74,6 +83,7 @@ enum llm_kv {
     LLM_KV_GENERAL_ARCHITECTURE,
     LLM_KV_GENERAL_QUANTIZATION_VERSION,
     LLM_KV_GENERAL_ALIGNMENT,
+    LLM_KV_GENERAL_FILE_TYPE,
     LLM_KV_GENERAL_NAME,
     LLM_KV_GENERAL_AUTHOR,
     LLM_KV_GENERAL_VERSION,
@@ -112,6 +122,7 @@ enum llm_kv {
     LLM_KV_RESIDUAL_SCALE,
     LLM_KV_EMBEDDING_SCALE,
     LLM_KV_TOKEN_SHIFT_COUNT,
+    LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -126,9 +137,15 @@ enum llm_kv {
     LLM_KV_ATTENTION_CAUSAL,
     LLM_KV_ATTENTION_Q_LORA_RANK,
     LLM_KV_ATTENTION_KV_LORA_RANK,
+    LLM_KV_ATTENTION_DECAY_LORA_RANK,
+    LLM_KV_ATTENTION_ICLR_LORA_RANK,
+    LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK,
+    LLM_KV_ATTENTION_GATE_LORA_RANK,
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
+    LLM_KV_ATTENTION_KEY_LENGTH_MLA,
+    LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -242,6 +259,8 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_NORM,
     LLM_TENSOR_ATTN_K_NORM,
     LLM_TENSOR_LAYER_OUT_NORM,
+    LLM_TENSOR_POST_ATTN_NORM,
+    LLM_TENSOR_POST_MLP_NORM,
     LLM_TENSOR_SSM_IN,
     LLM_TENSOR_SSM_CONV1D,
     LLM_TENSOR_SSM_X,
@@ -249,8 +268,20 @@ enum llm_tensor {
     LLM_TENSOR_SSM_A,
     LLM_TENSOR_SSM_D,
     LLM_TENSOR_SSM_OUT,
+    LLM_TENSOR_TIME_MIX_W0,
     LLM_TENSOR_TIME_MIX_W1,
     LLM_TENSOR_TIME_MIX_W2,
+    LLM_TENSOR_TIME_MIX_A0,
+    LLM_TENSOR_TIME_MIX_A1,
+    LLM_TENSOR_TIME_MIX_A2,
+    LLM_TENSOR_TIME_MIX_V0,
+    LLM_TENSOR_TIME_MIX_V1,
+    LLM_TENSOR_TIME_MIX_V2,
+    LLM_TENSOR_TIME_MIX_G1,
+    LLM_TENSOR_TIME_MIX_G2,
+    LLM_TENSOR_TIME_MIX_K_K,
+    LLM_TENSOR_TIME_MIX_K_A,
+    LLM_TENSOR_TIME_MIX_R_K,
     LLM_TENSOR_TIME_MIX_LERP_X,
     LLM_TENSOR_TIME_MIX_LERP_W,
     LLM_TENSOR_TIME_MIX_LERP_K,
@@ -277,6 +308,8 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_B,
     LLM_TENSOR_ATTN_KV_A_MQA,
     LLM_TENSOR_ATTN_KV_B,
+    LLM_TENSOR_ATTN_K_B,
+    LLM_TENSOR_ATTN_V_B,
     LLM_TENSOR_ATTN_Q_A_NORM,
     LLM_TENSOR_ATTN_KV_A_NORM,
     LLM_TENSOR_ATTN_SUB_NORM,
examples/talk-llama/llama-batch.h CHANGED
@@ -42,9 +42,9 @@ struct llama_sbatch {
     bool logits_all; // TODO: remove once lctx.logits_all is removed too
 
     // sorted indices into the batch
-    std::vector<size_t> ids;
+    std::vector<int64_t> ids;
     // batch indices of the output
-    std::vector<size_t> out_ids;
+    std::vector<int64_t> out_ids;
     std::vector<llama_sbatch_seq> seq;
 
     const llama_batch * batch = nullptr;
examples/talk-llama/llama-chat.cpp CHANGED
@@ -4,6 +4,7 @@
 
 #include <map>
 #include <sstream>
+#include <algorithm>
 
 #if __cplusplus >= 202000L
 #define LU8(x) (const char*)(u8##x)
@@ -58,6 +59,10 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "granite", LLM_CHAT_TEMPLATE_GRANITE },
     { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
     { "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
+    { "yandex", LLM_CHAT_TEMPLATE_YANDEX },
+    { "bailing", LLM_CHAT_TEMPLATE_BAILING },
+    { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
+    { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
 };
 
 llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -77,7 +82,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
     if (tmpl_contains("<|im_start|>")) {
         return tmpl_contains("<|im_sep|>")
             ? LLM_CHAT_TEMPLATE_PHI_4
-            : LLM_CHAT_TEMPLATE_CHATML;
+            : tmpl_contains("<end_of_utterance>")
+                ? LLM_CHAT_TEMPLATE_SMOLVLM // SmolVLM uses <|im_start|> as BOS, but it is NOT chatml
+                : LLM_CHAT_TEMPLATE_CHATML;
     } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
         if (tmpl_contains("[SYSTEM_PROMPT]")) {
             return LLM_CHAT_TEMPLATE_MISTRAL_V7;
@@ -117,6 +124,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_PHI_3;
     } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
         return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
+    } else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) {
+        return LLM_CHAT_TEMPLATE_GLMEDGE;
     } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
         return LLM_CHAT_TEMPLATE_ZEPHYR;
     } else if (tmpl_contains("bos_token + message['role']")) {
@@ -167,6 +176,12 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_GIGACHAT;
     } else if (tmpl_contains("<|role_start|>")) {
         return LLM_CHAT_TEMPLATE_MEGREZ;
+    } else if (tmpl_contains(" Ассистент:")) {
+        return LLM_CHAT_TEMPLATE_YANDEX;
+    } else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) {
+        return LLM_CHAT_TEMPLATE_BAILING;
+    } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
+        return LLM_CHAT_TEMPLATE_LLAMA4;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -566,6 +581,66 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<|role_start|>assistant<|role_end|>";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_YANDEX) {
+        // Yandex template ("\n\n" is defined as EOT token)
+
+        ss << "<s>";
+
+        for (size_t i = 0; i < chat.size(); i++) {
+            std::string role(chat[i]->role);
+            if (role == "user") {
+                ss << " Пользователь: " << chat[i]->content << "\n\n";
+            } else if (role == "assistant") {
+                ss << " Ассистент: " << chat[i]->content << "\n\n";
+            }
+        }
+
+        // Add generation prompt if needed
+        if (add_ass) {
+            ss << " Ассистент:[SEP]";
+        }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) {
+        // Bailing (Ling) template
+        for (auto message : chat) {
+            std::string role(message->role);
+
+            if (role == "user") {
+                role = "HUMAN";
+            } else {
+                std::transform(role.begin(), role.end(), role.begin(), ::toupper);
+            }
+
+            ss << "<role>" << role << "</role>" << message->content;
+        }
+
+        if (add_ass) {
+            ss << "<role>ASSISTANT</role>";
+        }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA4) {
+        // Llama 4
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|header_start|>" << role << "<|header_end|>\n\n" << trim(message->content) << "<|eot|>";
+        }
+        if (add_ass) {
+            ss << "<|header_start|>assistant<|header_end|>\n\n";
+        }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_SMOLVLM) {
+        // SmolVLM
+        ss << "<|im_start|>"; // uses <|im_start|> as BOS, but the actual content is NOT chatml
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << message->content << "\n\n";
+            } else if (role == "user") {
+                ss << "User: " << message->content << "<end_of_utterance>\n";
+            } else {
+                ss << "Assistant: " << message->content << "<end_of_utterance>\n";
+            }
+        }
+        if (add_ass) {
+            ss << "Assistant:";
+        }
     } else {
         // template not supported
         return -1;
@@ -584,4 +659,3 @@ int32_t llama_chat_builtin_templates(const char ** output, size_t len) {
     }
     return (int32_t) LLM_CHAT_TEMPLATES.size();
 }
-
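For context, a minimal sketch of how an application would exercise the newly added built-in "llama4" template through the public API. The `llama_chat_apply_template` signature and the resize-and-retry convention are assumptions based on the llama.h of this sync, not something this diff adds.

// Hedged sketch: format a conversation with the built-in "llama4" chat template.
// Assumes llama_chat_apply_template(tmpl, chat, n_msg, add_ass, buf, length) from
// the synced llama.h, returning the required buffer size or a negative value.
#include "llama.h"
#include <string>
#include <vector>

static std::string render_llama4(const std::vector<llama_chat_message> & msgs) {
    std::string buf(1024, '\0');
    int32_t n = llama_chat_apply_template("llama4", msgs.data(), msgs.size(),
                                          /*add_ass=*/true, buf.data(), (int32_t) buf.size());
    if (n > (int32_t) buf.size()) {
        // buffer was too small: grow it and format again
        buf.resize(n);
        n = llama_chat_apply_template("llama4", msgs.data(), msgs.size(),
                                      /*add_ass=*/true, buf.data(), (int32_t) buf.size());
    }
    if (n < 0) {
        return ""; // template not supported
    }
    buf.resize(n);
    return buf;
}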
examples/talk-llama/llama-chat.h CHANGED
@@ -38,6 +38,10 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_GRANITE,
     LLM_CHAT_TEMPLATE_GIGACHAT,
     LLM_CHAT_TEMPLATE_MEGREZ,
+    LLM_CHAT_TEMPLATE_YANDEX,
+    LLM_CHAT_TEMPLATE_BAILING,
+    LLM_CHAT_TEMPLATE_LLAMA4,
+    LLM_CHAT_TEMPLATE_SMOLVLM,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
 
examples/talk-llama/llama-context.cpp CHANGED
The diff for this file is too large to render. See raw diff
 
examples/talk-llama/llama-context.h CHANGED
@@ -3,66 +3,213 @@
3
  #include "llama.h"
4
  #include "llama-batch.h"
5
  #include "llama-cparams.h"
6
- #include "llama-model.h"
7
- #include "llama-kv-cache.h"
8
  #include "llama-adapter.h"
9
 
10
  #include "ggml-cpp.h"
11
 
12
  #include <map>
13
- #include <unordered_map>
14
  #include <vector>
15
- #include <set>
 
 
 
 
 
16
 
17
  struct llama_context {
18
- llama_context(const llama_model & model)
19
- : model(model)
20
- , t_start_us(model.t_start_us)
21
- , t_load_us(model.t_load_us) {}
22
 
23
- const struct llama_model & model;
24
 
25
- struct llama_cparams cparams;
26
- struct llama_sbatch sbatch; // TODO: revisit if needed
27
- struct llama_kv_cache kv_self;
28
- struct llama_adapter_cvec cvec;
29
 
30
- std::unordered_map<struct llama_adapter_lora *, float> lora;
31
 
32
- std::vector<ggml_backend_ptr> backends;
33
- std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
 
 
 
34
 
35
- ggml_backend_t backend_cpu = nullptr;
 
36
 
37
- ggml_threadpool_t threadpool = nullptr;
38
- ggml_threadpool_t threadpool_batch = nullptr;
39
 
40
- bool has_evaluated_once = false;
41
 
42
- mutable int64_t t_start_us;
43
- mutable int64_t t_load_us;
44
- mutable int64_t t_p_eval_us = 0;
45
- mutable int64_t t_eval_us = 0;
46
 
47
- mutable int64_t t_compute_start_us = 0;
48
- mutable int64_t n_queued_tokens = 0;
49
 
50
- mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
51
- mutable int32_t n_eval = 0; // number of eval calls
 
52
 
53
- // host buffer for the model output (logits and embeddings)
54
- ggml_backend_buffer_ptr buf_output;
 
55
 
56
- // decode output (2-dimensional array: [n_outputs][n_vocab])
57
- size_t logits_size = 0; // capacity (of floats) for logits
58
- float * logits = nullptr;
59
 
60
- std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
61
- size_t output_size = 0; // capacity (of tokens positions) for the output buffers
62
- int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
63
64
  bool logits_all = false;
65
 
 
 
 
 
66
  // embeddings output (2-dimensional array: [n_outputs][n_embd])
67
  // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
68
  size_t embd_size = 0; // capacity (of floats) for embeddings
@@ -72,57 +219,47 @@ struct llama_context {
72
  // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
73
  std::map<llama_seq_id, std::vector<float>> embd_seq;
74
 
75
- // whether we are computing encoder output or decoder output
76
- bool is_encoding = false;
77
-
78
- // TODO: find a better way to accommodate mutli-dimension position encoding methods
79
- // number of position id each token get, 1 for each token in most cases.
80
- // when using m-rope, it will be 3 position ids per token to representing 3 dimension coordinate.
81
- int n_pos_per_token = 1;
82
 
83
- // output of the encoder part of the encoder-decoder models
84
- std::vector<float> embd_enc;
85
- std::vector<std::set<llama_seq_id>> seq_ids_enc;
86
 
87
- // memory buffers used to evaluate the model
88
- std::vector<uint8_t> buf_compute_meta;
89
  ggml_backend_sched_ptr sched;
90
91
  ggml_abort_callback abort_callback = nullptr;
92
  void * abort_callback_data = nullptr;
93
 
94
- // input tensors
95
- struct ggml_tensor * inp_tokens; // I32 [n_batch]
96
- struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
97
- struct ggml_tensor * inp_pos; // I32 [n_batch]
98
- struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
99
- struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
100
- struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
101
- struct ggml_tensor * inp_K_shift; // I32 [kv_size]
102
- struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
103
- struct ggml_tensor * inp_cls; // I32 [n_batch]
104
- struct ggml_tensor * inp_s_copy; // I32 [kv_size]
105
- struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
106
- struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
107
- struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
108
- struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
109
- struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
110
- };
111
 
112
- // TODO: make these methods of llama_context
113
- void llama_set_k_shift(struct llama_context & lctx);
114
 
115
- void llama_set_s_copy(struct llama_context & lctx);
 
116
 
117
- void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch);
118
 
119
- // Make sure enough space is available for outputs.
120
- // Returns max number of outputs for which space was reserved.
121
- size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs);
 
 
122
 
123
- // make the outputs have the same order they had in the user-provided batch
124
- void llama_output_reorder(struct llama_context & ctx);
125
 
126
- // For internal test use
127
- // TODO: remove
128
- const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);
 
3
  #include "llama.h"
4
  #include "llama-batch.h"
5
  #include "llama-cparams.h"
6
+ #include "llama-graph.h"
 
7
  #include "llama-adapter.h"
8
 
9
  #include "ggml-cpp.h"
10
 
11
  #include <map>
 
12
  #include <vector>
13
+
14
+ struct llama_model;
15
+ struct llama_kv_cache;
16
+
17
+ class llama_io_read_i;
18
+ class llama_io_write_i;
19
 
20
  struct llama_context {
21
+ // init scheduler and compute buffers, reserve worst-case graphs
22
+ llama_context(
23
+ const llama_model & model,
24
+ llama_context_params params);
25
 
26
+ ~llama_context();
27
 
28
+ void synchronize();
 
 
 
29
 
30
+ const llama_model & get_model() const;
31
 
32
+ uint32_t n_ctx() const;
33
+ uint32_t n_ctx_per_seq() const;
34
+ uint32_t n_batch() const;
35
+ uint32_t n_ubatch() const;
36
+ uint32_t n_seq_max() const;
37
 
38
+ uint32_t n_threads() const;
39
+ uint32_t n_threads_batch() const;
40
 
41
+ llama_kv_cache * get_kv_self();
42
+ const llama_kv_cache * get_kv_self() const;
43
 
44
+ void kv_self_update();
45
 
46
+ enum llama_pooling_type pooling_type() const;
 
 
 
47
 
48
+ float * get_logits();
49
+ float * get_logits_ith(int32_t i);
50
 
51
+ float * get_embeddings();
52
+ float * get_embeddings_ith(int32_t i);
53
+ float * get_embeddings_seq(llama_seq_id seq_id);
54
 
55
+ void attach_threadpool(
56
+ ggml_threadpool_t threadpool,
57
+ ggml_threadpool_t threadpool_batch);
58
 
59
+ void detach_threadpool();
 
 
60
 
61
+ void set_n_threads(int32_t n_threads, int32_t n_threads_batch);
62
+
63
+ void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
64
+
65
+ void set_embeddings (bool value);
66
+ void set_causal_attn(bool value);
67
+ void set_warmup(bool value);
68
+
69
+ void set_adapter_lora(
70
+ llama_adapter_lora * adapter,
71
+ float scale);
72
+
73
+ bool rm_adapter_lora(
74
+ llama_adapter_lora * adapter);
75
+
76
+ void clear_adapter_lora();
77
+
78
+ bool apply_adapter_cvec(
79
+ const float * data,
80
+ size_t len,
81
+ int32_t n_embd,
82
+ int32_t il_start,
83
+ int32_t il_end);
84
+
85
+ int encode(llama_batch & inp_batch);
86
+ int decode(llama_batch & inp_batch);
87
+
88
+ //
89
+ // state save/load
90
+ //
91
+
92
+ size_t state_get_size();
93
+ size_t state_get_data( uint8_t * dst, size_t size);
94
+ size_t state_set_data(const uint8_t * src, size_t size);
95
+
96
+ size_t state_seq_get_size(llama_seq_id seq_id);
97
+ size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
98
+ size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
99
+
100
+ bool state_load_file(
101
+ const char * filepath,
102
+ llama_token * tokens_out,
103
+ size_t n_token_capacity,
104
+ size_t * n_token_count_out);
105
+
106
+ bool state_save_file(
107
+ const char * filepath,
108
+ const llama_token * tokens,
109
+ size_t n_token_count);
110
+
111
+ size_t state_seq_load_file(
112
+ llama_seq_id seq_id,
113
+ const char * filepath,
114
+ llama_token * tokens_out,
115
+ size_t n_token_capacity,
116
+ size_t * n_token_count_out);
117
+
118
+ size_t state_seq_save_file(
119
+ llama_seq_id seq_id,
120
+ const char * filepath,
121
+ const llama_token * tokens,
122
+ size_t n_token_count);
123
+
124
+ //
125
+ // perf
126
+ //
127
+
128
+ llama_perf_context_data perf_get_data() const;
129
+ void perf_reset();
130
+
131
+ private:
132
+ //
133
+ // output
134
+ //
135
 
136
+ // Make sure enough space is available for outputs.
137
+ // Returns max number of outputs for which space was reserved.
138
+ int32_t output_reserve(int32_t n_outputs);
139
+
140
+ // make the outputs have the same order they had in the user-provided batch
141
+ // TODO: maybe remove this
142
+ void output_reorder();
143
+
144
+ //
145
+ // graph
146
+ //
147
+
148
+ int32_t graph_max_nodes() const;
149
+
150
+ // zero-out inputs and create the ctx_compute for the compute graph
151
+ ggml_cgraph * graph_init();
152
+
153
+ llm_graph_result_ptr graph_build(
154
+ ggml_context * ctx,
155
+ ggml_cgraph * gf,
156
+ const llama_ubatch & ubatch,
157
+ llm_graph_type gtype);
158
+
159
+ // returns the result of ggml_backend_sched_graph_compute_async execution
160
+ ggml_status graph_compute(
161
+ ggml_cgraph * gf,
162
+ bool batched);
163
+
164
+ llm_graph_cb graph_get_cb() const;
165
+
166
+ // used by kv_self_update()
167
+ ggml_tensor * build_rope_shift(
168
+ ggml_context * ctx0,
169
+ ggml_tensor * cur,
170
+ ggml_tensor * shift,
171
+ ggml_tensor * factors,
172
+ float freq_base,
173
+ float freq_scale,
174
+ ggml_backend_buffer * bbuf) const;
175
+
176
+ llm_graph_result_ptr build_kv_self_shift(
177
+ ggml_context * ctx0,
178
+ ggml_cgraph * gf) const;
179
+
180
+ llm_graph_result_ptr build_kv_self_defrag(
181
+ ggml_context * ctx0,
182
+ ggml_cgraph * gf) const;
183
+
184
+ // TODO: read/write lora adapters and cvec
185
+ size_t state_write_data(llama_io_write_i & io);
186
+ size_t state_read_data (llama_io_read_i & io);
187
+
188
+ size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
189
+ size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);
190
+
191
+ //
192
+ // members
193
+ //
194
+
195
+ const llama_model & model;
196
+
197
+ llama_cparams cparams;
198
+ llama_adapter_cvec cvec;
199
+ llama_adapter_loras loras;
200
+ llama_sbatch sbatch;
201
+
202
+ llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
203
+
204
+ std::unique_ptr<llama_kv_cache_unified> kv_self;
205
+
206
+ // TODO: remove
207
  bool logits_all = false;
208
 
209
+ // decode output (2-dimensional array: [n_outputs][n_vocab])
210
+ size_t logits_size = 0; // capacity (of floats) for logits
211
+ float * logits = nullptr;
212
+
213
  // embeddings output (2-dimensional array: [n_outputs][n_embd])
214
  // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
215
  size_t embd_size = 0; // capacity (of floats) for embeddings
 
219
  // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
220
  std::map<llama_seq_id, std::vector<float>> embd_seq;
221
 
222
+ int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
223
+ int32_t n_outputs_max = 0; // capacity (of tokens positions) for the output buffers
 
 
 
 
 
224
 
225
+ std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
 
 
226
 
 
 
227
  ggml_backend_sched_ptr sched;
228
 
229
+ ggml_backend_t backend_cpu = nullptr;
230
+ std::vector<ggml_backend_ptr> backends;
231
+
232
+ ggml_context_ptr ctx_compute;
233
+
234
+ ggml_threadpool_t threadpool = nullptr;
235
+ ggml_threadpool_t threadpool_batch = nullptr;
236
+
237
  ggml_abort_callback abort_callback = nullptr;
238
  void * abort_callback_data = nullptr;
239
 
240
+ std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
241
+
242
+ // buffer types used for the compute buffer of each backend
243
+ std::vector<ggml_backend_t> backend_ptrs;
244
+ std::vector<ggml_backend_buffer_type_t> backend_buft;
245
 
246
+ // memory buffers used to evaluate the model
247
+ std::vector<uint8_t> buf_compute_meta;
248
 
249
+ // host buffer for the model output (logits and embeddings)
250
+ ggml_backend_buffer_ptr buf_output;
251
 
252
+ bool has_evaluated_once = false;
253
 
254
+ // perf
255
+ mutable int64_t t_start_us = 0;
256
+ mutable int64_t t_load_us = 0;
257
+ mutable int64_t t_p_eval_us = 0;
258
+ mutable int64_t t_eval_us = 0;
259
 
260
+ mutable int64_t t_compute_start_us = 0;
261
+ mutable int64_t n_queued_tokens = 0;
262
 
263
+ mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
264
+ mutable int32_t n_eval = 0; // number of eval calls
265
+ };
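The context members above replace the former free functions (llama_set_inputs, llama_output_reserve, llama_output_reorder, ...). A hedged sketch of how the new state_save_file / state_load_file members are reached through the existing public llama.h wrappers (the forwarding itself lives in llama-context.cpp, whose diff is not rendered above; the file name and buffer sizes are illustrative):

#include "llama.h"

#include <vector>

// Round-trips the context state (KV cache plus output bookkeeping) through a
// session file via the public C API, which forwards into the new member functions.
static bool roundtrip_state(llama_context * ctx, const std::vector<llama_token> & prompt) {
    if (!llama_state_save_file(ctx, "session.bin", prompt.data(), prompt.size())) {
        return false;
    }

    std::vector<llama_token> tokens(prompt.size());
    size_t n_loaded = 0;

    return llama_state_load_file(ctx, "session.bin", tokens.data(), tokens.size(), &n_loaded);
}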
examples/talk-llama/llama-cparams.h CHANGED
@@ -29,6 +29,7 @@ struct llama_cparams {
29
  bool offload_kqv;
30
  bool flash_attn;
31
  bool no_perf;
 
32
 
33
  enum llama_pooling_type pooling_type;
34
 
 
29
  bool offload_kqv;
30
  bool flash_attn;
31
  bool no_perf;
32
+ bool warmup;
33
 
34
  enum llama_pooling_type pooling_type;
35
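The new warmup flag added here is consumed further down in this same commit, in the llm_graph_context constructor of llama-graph.cpp:

n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),

i.e. during a warmup pass all experts are selected, presumably so that every expert tensor is touched (and paged in) at least once.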
 
examples/talk-llama/llama-grammar.cpp CHANGED
@@ -345,194 +345,194 @@ const char * llama_grammar_parser::parse_sequence(
345
  size_t last_sym_start = rule.size();
346
  const char * pos = src;
347
 
348
- auto handle_repetitions = [&](int min_times, int max_times) {
349
 
350
- if (last_sym_start == rule.size()) {
351
- throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
352
- }
353
 
354
- // apply transformation to previous symbol (last_sym_start to end) according to
355
- // the following rewrite rules:
356
- // S{m,n} --> S S S (m times) S'(n-m)
357
- // S'(x) ::= S S'(x-1) |
358
- // (... n-m definitions of these S' rules ...)
359
- // S'(1) ::= S |
360
- // S{m,} --> S S S (m times) S'
361
- // S' ::= S S' |
362
- // S* --> S{0,}
363
- // --> S' ::= S S' |
364
- // S+ --> S{1,}
365
- // --> S S'
366
- // S' ::= S S' |
367
- // S? --> S{0,1}
368
- // --> S'
369
- // S' ::= S |
370
-
371
- llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
372
- if (min_times == 0) {
373
- rule.resize(last_sym_start);
374
- } else {
375
- // Repeat the previous elements (min_times - 1) times
376
- for (int i = 1; i < min_times; i++) {
377
- rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
378
- }
379
  }
 
380
 
381
- uint32_t last_rec_rule_id = 0;
382
- auto n_opt = max_times < 0 ? 1 : max_times - min_times;
383
 
384
- llama_grammar_rule rec_rule(prev_rule);
385
- for (int i = 0; i < n_opt; i++) {
386
- rec_rule.resize(prev_rule.size());
387
- uint32_t rec_rule_id = generate_symbol_id( rule_name);
388
- if (i > 0 || max_times < 0) {
389
- rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
390
- }
391
- rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
392
- rec_rule.push_back({LLAMA_GRETYPE_END, 0});
393
- add_rule( rec_rule_id, rec_rule);
394
- last_rec_rule_id = rec_rule_id;
395
  }
396
- if (n_opt > 0) {
397
- rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
398
- }
399
- };
 
 
 
 
 
400
 
401
- while (*pos) {
402
- if (*pos == '"') { // literal string
403
- pos++;
404
- last_sym_start = rule.size();
405
- while (*pos != '"') {
406
- if (!*pos) {
407
- throw std::runtime_error("unexpected end of input");
408
- }
409
- auto char_pair = parse_char(pos);
410
- pos = char_pair.second;
411
- rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
412
  }
413
- pos = parse_space(pos + 1, is_nested);
414
- } else if (*pos == '[') { // char range(s)
 
 
 
 
 
 
 
415
  pos++;
416
- enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
417
- if (*pos == '^') {
418
- pos++;
419
- start_type = LLAMA_GRETYPE_CHAR_NOT;
 
 
420
  }
421
- last_sym_start = rule.size();
422
- while (*pos != ']') {
423
- if (!*pos) {
 
 
 
 
 
 
424
  throw std::runtime_error("unexpected end of input");
425
  }
426
- auto char_pair = parse_char(pos);
427
- pos = char_pair.second;
428
- enum llama_gretype type = last_sym_start < rule.size()
429
- ? LLAMA_GRETYPE_CHAR_ALT
430
- : start_type;
431
-
432
- rule.push_back({type, char_pair.first});
433
- if (pos[0] == '-' && pos[1] != ']') {
434
- if (!pos[1]) {
435
- throw std::runtime_error("unexpected end of input");
436
- }
437
- auto endchar_pair = parse_char(pos + 1);
438
- pos = endchar_pair.second;
439
- rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
440
- }
441
- }
442
- pos = parse_space(pos + 1, is_nested);
443
- } else if (is_word_char(*pos)) { // rule reference
444
- const char * name_end = parse_name(pos);
445
- uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
446
- pos = parse_space(name_end, is_nested);
447
- last_sym_start = rule.size();
448
- rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
449
- } else if (*pos == '(') { // grouping
450
- // parse nested alternates into synthesized rule
451
- pos = parse_space(pos + 1, true);
452
- uint32_t sub_rule_id = generate_symbol_id(rule_name);
453
- pos = parse_alternates(pos, rule_name, sub_rule_id, true);
454
- last_sym_start = rule.size();
455
- // output reference to synthesized rule
456
- rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
457
- if (*pos != ')') {
458
- throw std::runtime_error(std::string("expecting ')' at ") + pos);
459
  }
460
  pos = parse_space(pos + 1, is_nested);
461
- } else if (*pos == '.') { // any char
462
- last_sym_start = rule.size();
463
- rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
464
- pos = parse_space(pos + 1, is_nested);
465
- } else if (*pos == '*') {
466
- pos = parse_space(pos + 1, is_nested);
467
- handle_repetitions(0, -1);
468
- } else if (*pos == '+') {
469
- pos = parse_space(pos + 1, is_nested);
470
- handle_repetitions(1, -1);
471
- } else if (*pos == '?') {
472
- pos = parse_space(pos + 1, is_nested);
473
- handle_repetitions(0, 1);
474
- } else if (*pos == '{') {
475
  pos = parse_space(pos + 1, is_nested);
476
 
477
- if (!is_digit_char(*pos)) {
478
- throw std::runtime_error(std::string("expecting an int at ") + pos);
 
 
479
  }
480
- const char * int_end = parse_int(pos);
481
- int min_times = std::stoul(std::string(pos, int_end - pos));
482
- pos = parse_space(int_end, is_nested);
483
-
484
- int max_times = -1;
485
-
486
- if (*pos == '}') {
487
- max_times = min_times;
488
- pos = parse_space(pos + 1, is_nested);
489
- } else if (*pos == ',') {
490
- pos = parse_space(pos + 1, is_nested);
491
-
492
- if (is_digit_char(*pos)) {
493
- const char * int_end = parse_int(pos);
494
- max_times = std::stoul(std::string(pos, int_end - pos));
495
- pos = parse_space(int_end, is_nested);
496
- }
497
 
498
- if (*pos != '}') {
499
- throw std::runtime_error(std::string("expecting '}' at ") + pos);
500
- }
501
- pos = parse_space(pos + 1, is_nested);
502
- } else {
503
- throw std::runtime_error(std::string("expecting ',' at ") + pos);
504
  }
505
- handle_repetitions(min_times, max_times);
506
  } else {
507
- break;
508
  }
 
 
 
509
  }
510
- return pos;
511
  }
 
 
512
 
513
  const char * llama_grammar_parser::parse_rule(const char * src) {
514
- const char * name_end = parse_name(src);
515
- const char * pos = parse_space(name_end, false);
516
- size_t name_len = name_end - src;
517
- uint32_t rule_id = get_symbol_id(src, name_len);
518
- const std::string name(src, name_len);
519
-
520
- if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
521
- throw std::runtime_error(std::string("expecting ::= at ") + pos);
522
- }
523
- pos = parse_space(pos + 3, true);
524
 
525
- pos = parse_alternates(pos, name, rule_id, false);
526
 
527
- if (*pos == '\r') {
528
- pos += pos[1] == '\n' ? 2 : 1;
529
- } else if (*pos == '\n') {
530
- pos++;
531
- } else if (*pos) {
532
- throw std::runtime_error(std::string("expecting newline or end at ") + pos);
533
- }
534
- return parse_space(pos, true);
535
  }
 
 
536
 
537
  bool llama_grammar_parser::parse(const char * src) {
538
  try {
@@ -969,7 +969,7 @@ struct llama_grammar * llama_grammar_init_impl(
969
  /* .awaiting_trigger = */ false,
970
  /* .trigger_buffer = */ "",
971
  /* .trigger_tokens = */ {},
972
- /* .trigger_words = */ {},
973
  };
974
  }
975
 
@@ -978,19 +978,15 @@ struct llama_grammar * llama_grammar_init_impl(
978
  const char * grammar_str,
979
  const char * grammar_root,
980
  bool lazy,
981
- const char ** trigger_words,
982
- size_t num_trigger_words,
983
  const llama_token * trigger_tokens,
984
  size_t num_trigger_tokens) {
985
  llama_grammar_parser parser;
986
 
987
  // if there is a grammar, parse it
988
- if (!parser.parse(grammar_str)) {
989
- return nullptr;
990
- }
991
-
992
- // will be empty (default) if there are parse errors
993
- if (parser.rules.empty()) {
994
  fprintf(stderr, "%s: failed to parse grammar\n", __func__);
995
  return nullptr;
996
  }
@@ -1054,14 +1050,16 @@ struct llama_grammar * llama_grammar_init_impl(
1054
  } while (true);
1055
 
1056
  std::vector<llama_token> vec_trigger_tokens;
1057
- std::vector<std::string> vec_trigger_words;
1058
  for (size_t i = 0; i < num_trigger_tokens; i++) {
1059
  GGML_ASSERT(trigger_tokens != nullptr);
1060
  vec_trigger_tokens.push_back(trigger_tokens[i]);
1061
  }
1062
- for (size_t i = 0; i < num_trigger_words; i++) {
1063
- GGML_ASSERT(trigger_words != nullptr);
1064
- vec_trigger_words.push_back(trigger_words[i]);
 
 
1065
  }
1066
 
1067
  // Important: vec_rules has to be moved here, not copied, because stacks contains
@@ -1076,7 +1074,7 @@ struct llama_grammar * llama_grammar_init_impl(
1076
  /* .awaiting_trigger = */ lazy,
1077
  /* .trigger_buffer = */ "",
1078
  std::move(vec_trigger_tokens),
1079
- std::move(vec_trigger_words),
1080
  };
1081
  }
1082
 
@@ -1089,7 +1087,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
1089
  }
1090
 
1091
  struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
1092
- llama_grammar * result = new llama_grammar {
1093
  grammar.vocab,
1094
  grammar.rules,
1095
  grammar.stacks,
@@ -1098,7 +1096,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
1098
  grammar.awaiting_trigger,
1099
  grammar.trigger_buffer,
1100
  grammar.trigger_tokens,
1101
- grammar.trigger_words,
1102
  };
1103
 
1104
  // redirect elements in stacks to point to new rules
@@ -1173,20 +1171,22 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
1173
  LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
1174
  return;
1175
  } else {
1176
- // TODO: consider a smarter incremental substring search algorithm (store last position to search from).
1177
  grammar.trigger_buffer += piece;
1178
- for (const auto & word : grammar.trigger_words) {
1179
- auto pos = grammar.trigger_buffer.find(word);
1180
- if (pos != std::string::npos) {
 
1181
  grammar.awaiting_trigger = false;
1182
- auto constrained_str = grammar.trigger_buffer.substr(pos);
 
 
1183
  grammar.trigger_buffer.clear();
1184
  llama_grammar_accept_str(grammar, constrained_str);
1185
- LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str());
1186
  return;
1187
  }
1188
  }
1189
- LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`) (buffer: `%s`)\n", token, piece.c_str(), grammar.trigger_buffer.c_str());
1190
  return;
1191
  }
1192
  }
 
345
  size_t last_sym_start = rule.size();
346
  const char * pos = src;
347
 
348
+ auto handle_repetitions = [&](int min_times, int max_times) {
349
 
350
+ if (last_sym_start == rule.size()) {
351
+ throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
352
+ }
353
 
354
+ // apply transformation to previous symbol (last_sym_start to end) according to
355
+ // the following rewrite rules:
356
+ // S{m,n} --> S S S (m times) S'(n-m)
357
+ // S'(x) ::= S S'(x-1) |
358
+ // (... n-m definitions of these S' rules ...)
359
+ // S'(1) ::= S |
360
+ // S{m,} --> S S S (m times) S'
361
+ // S' ::= S S' |
362
+ // S* --> S{0,}
363
+ // --> S' ::= S S' |
364
+ // S+ --> S{1,}
365
+ // --> S S'
366
+ // S' ::= S S' |
367
+ // S? --> S{0,1}
368
+ // --> S'
369
+ // S' ::= S |
370
+
371
+ llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
372
+ if (min_times == 0) {
373
+ rule.resize(last_sym_start);
374
+ } else {
375
+ // Repeat the previous elements (min_times - 1) times
376
+ for (int i = 1; i < min_times; i++) {
377
+ rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
 
378
  }
379
+ }
380
 
381
+ uint32_t last_rec_rule_id = 0;
382
+ auto n_opt = max_times < 0 ? 1 : max_times - min_times;
383
 
384
+ llama_grammar_rule rec_rule(prev_rule);
385
+ for (int i = 0; i < n_opt; i++) {
386
+ rec_rule.resize(prev_rule.size());
387
+ uint32_t rec_rule_id = generate_symbol_id( rule_name);
388
+ if (i > 0 || max_times < 0) {
389
+ rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
 
 
 
 
 
390
  }
391
+ rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
392
+ rec_rule.push_back({LLAMA_GRETYPE_END, 0});
393
+ add_rule( rec_rule_id, rec_rule);
394
+ last_rec_rule_id = rec_rule_id;
395
+ }
396
+ if (n_opt > 0) {
397
+ rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
398
+ }
399
+ };
400
 
401
+ while (*pos) {
402
+ if (*pos == '"') { // literal string
403
+ pos++;
404
+ last_sym_start = rule.size();
405
+ while (*pos != '"') {
406
+ if (!*pos) {
407
+ throw std::runtime_error("unexpected end of input");
 
 
 
 
408
  }
409
+ auto char_pair = parse_char(pos);
410
+ pos = char_pair.second;
411
+ rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
412
+ }
413
+ pos = parse_space(pos + 1, is_nested);
414
+ } else if (*pos == '[') { // char range(s)
415
+ pos++;
416
+ enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
417
+ if (*pos == '^') {
418
  pos++;
419
+ start_type = LLAMA_GRETYPE_CHAR_NOT;
420
+ }
421
+ last_sym_start = rule.size();
422
+ while (*pos != ']') {
423
+ if (!*pos) {
424
+ throw std::runtime_error("unexpected end of input");
425
  }
426
+ auto char_pair = parse_char(pos);
427
+ pos = char_pair.second;
428
+ enum llama_gretype type = last_sym_start < rule.size()
429
+ ? LLAMA_GRETYPE_CHAR_ALT
430
+ : start_type;
431
+
432
+ rule.push_back({type, char_pair.first});
433
+ if (pos[0] == '-' && pos[1] != ']') {
434
+ if (!pos[1]) {
435
  throw std::runtime_error("unexpected end of input");
436
  }
437
+ auto endchar_pair = parse_char(pos + 1);
438
+ pos = endchar_pair.second;
439
+ rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440
  }
441
+ }
442
+ pos = parse_space(pos + 1, is_nested);
443
+ } else if (is_word_char(*pos)) { // rule reference
444
+ const char * name_end = parse_name(pos);
445
+ uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
446
+ pos = parse_space(name_end, is_nested);
447
+ last_sym_start = rule.size();
448
+ rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
449
+ } else if (*pos == '(') { // grouping
450
+ // parse nested alternates into synthesized rule
451
+ pos = parse_space(pos + 1, true);
452
+ uint32_t sub_rule_id = generate_symbol_id(rule_name);
453
+ pos = parse_alternates(pos, rule_name, sub_rule_id, true);
454
+ last_sym_start = rule.size();
455
+ // output reference to synthesized rule
456
+ rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
457
+ if (*pos != ')') {
458
+ throw std::runtime_error(std::string("expecting ')' at ") + pos);
459
+ }
460
+ pos = parse_space(pos + 1, is_nested);
461
+ } else if (*pos == '.') { // any char
462
+ last_sym_start = rule.size();
463
+ rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
464
+ pos = parse_space(pos + 1, is_nested);
465
+ } else if (*pos == '*') {
466
+ pos = parse_space(pos + 1, is_nested);
467
+ handle_repetitions(0, -1);
468
+ } else if (*pos == '+') {
469
+ pos = parse_space(pos + 1, is_nested);
470
+ handle_repetitions(1, -1);
471
+ } else if (*pos == '?') {
472
+ pos = parse_space(pos + 1, is_nested);
473
+ handle_repetitions(0, 1);
474
+ } else if (*pos == '{') {
475
+ pos = parse_space(pos + 1, is_nested);
476
+
477
+ if (!is_digit_char(*pos)) {
478
+ throw std::runtime_error(std::string("expecting an int at ") + pos);
479
+ }
480
+ const char * int_end = parse_int(pos);
481
+ int min_times = std::stoul(std::string(pos, int_end - pos));
482
+ pos = parse_space(int_end, is_nested);
483
+
484
+ int max_times = -1;
485
+
486
+ if (*pos == '}') {
487
+ max_times = min_times;
488
  pos = parse_space(pos + 1, is_nested);
489
+ } else if (*pos == ',') {
 
 
 
 
 
 
 
 
 
 
 
 
 
490
  pos = parse_space(pos + 1, is_nested);
491
 
492
+ if (is_digit_char(*pos)) {
493
+ const char * int_end = parse_int(pos);
494
+ max_times = std::stoul(std::string(pos, int_end - pos));
495
+ pos = parse_space(int_end, is_nested);
496
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
497
 
498
+ if (*pos != '}') {
499
+ throw std::runtime_error(std::string("expecting '}' at ") + pos);
 
 
 
 
500
  }
501
+ pos = parse_space(pos + 1, is_nested);
502
  } else {
503
+ throw std::runtime_error(std::string("expecting ',' at ") + pos);
504
  }
505
+ handle_repetitions(min_times, max_times);
506
+ } else {
507
+ break;
508
  }
 
509
  }
510
+ return pos;
511
+ }
512
 
513
  const char * llama_grammar_parser::parse_rule(const char * src) {
514
+ const char * name_end = parse_name(src);
515
+ const char * pos = parse_space(name_end, false);
516
+ size_t name_len = name_end - src;
517
+ uint32_t rule_id = get_symbol_id(src, name_len);
518
+ const std::string name(src, name_len);
519
+
520
+ if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
521
+ throw std::runtime_error(std::string("expecting ::= at ") + pos);
522
+ }
523
+ pos = parse_space(pos + 3, true);
524
 
525
+ pos = parse_alternates(pos, name, rule_id, false);
526
 
527
+ if (*pos == '\r') {
528
+ pos += pos[1] == '\n' ? 2 : 1;
529
+ } else if (*pos == '\n') {
530
+ pos++;
531
+ } else if (*pos) {
532
+ throw std::runtime_error(std::string("expecting newline or end at ") + pos);
 
 
533
  }
534
+ return parse_space(pos, true);
535
+ }
536
 
537
  bool llama_grammar_parser::parse(const char * src) {
538
  try {
 
969
  /* .awaiting_trigger = */ false,
970
  /* .trigger_buffer = */ "",
971
  /* .trigger_tokens = */ {},
972
+ /* .trigger_patterns = */ {},
973
  };
974
  }
975
 
 
978
  const char * grammar_str,
979
  const char * grammar_root,
980
  bool lazy,
981
+ const char ** trigger_patterns,
982
+ size_t num_trigger_patterns,
983
  const llama_token * trigger_tokens,
984
  size_t num_trigger_tokens) {
985
  llama_grammar_parser parser;
986
 
987
  // if there is a grammar, parse it
988
+ // rules will be empty (default) if there are parse errors
989
+ if (!parser.parse(grammar_str) || parser.rules.empty()) {
 
 
 
 
990
  fprintf(stderr, "%s: failed to parse grammar\n", __func__);
991
  return nullptr;
992
  }
 
1050
  } while (true);
1051
 
1052
  std::vector<llama_token> vec_trigger_tokens;
1053
+ std::vector<llama_grammar_trigger_pattern> vec_trigger_patterns;
1054
  for (size_t i = 0; i < num_trigger_tokens; i++) {
1055
  GGML_ASSERT(trigger_tokens != nullptr);
1056
  vec_trigger_tokens.push_back(trigger_tokens[i]);
1057
  }
1058
+ for (size_t i = 0; i < num_trigger_patterns; i++) {
1059
+ GGML_ASSERT(trigger_patterns != nullptr);
1060
+ auto & trigger = vec_trigger_patterns.emplace_back();
1061
+ trigger.pattern = trigger_patterns[i];
1062
+ trigger.regex = std::regex(trigger.pattern);
1063
  }
1064
 
1065
  // Important: vec_rules has to be moved here, not copied, because stacks contains
 
1074
  /* .awaiting_trigger = */ lazy,
1075
  /* .trigger_buffer = */ "",
1076
  std::move(vec_trigger_tokens),
1077
+ std::move(vec_trigger_patterns),
1078
  };
1079
  }
1080
 
 
1087
  }
1088
 
1089
  struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
1090
+ auto * result = new llama_grammar {
1091
  grammar.vocab,
1092
  grammar.rules,
1093
  grammar.stacks,
 
1096
  grammar.awaiting_trigger,
1097
  grammar.trigger_buffer,
1098
  grammar.trigger_tokens,
1099
+ grammar.trigger_patterns,
1100
  };
1101
 
1102
  // redirect elements in stacks to point to new rules
 
1171
  LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
1172
  return;
1173
  } else {
 
1174
  grammar.trigger_buffer += piece;
1175
+
1176
+ std::smatch match;
1177
+ for (const auto & trigger_pattern : grammar.trigger_patterns) {
1178
+ if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
1179
  grammar.awaiting_trigger = false;
1180
+ // get from the first match to the end of the string
1181
+ auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
1182
+ // std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
1183
  grammar.trigger_buffer.clear();
1184
  llama_grammar_accept_str(grammar, constrained_str);
1185
+ LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
1186
  return;
1187
  }
1188
  }
1189
+ LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str());
1190
  return;
1191
  }
1192
  }
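As a concrete instance of the rewrite rules documented in handle_repetitions above, item{2,4} (m = 2, n = 4) expands as follows, with readable rule names standing in for the generated numeric ids:

item{2,4}    -->  item item item'(2)
item'(2) ::=  item item'(1) |
item'(1) ::=  item |

that is, two mandatory copies of item followed by a chain of up to n - m = 2 optional ones.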
examples/talk-llama/llama-grammar.h CHANGED
@@ -3,6 +3,7 @@
3
  #include "llama.h"
4
 
5
  #include <map>
 
6
  #include <string>
7
  #include <vector>
8
 
@@ -105,6 +106,11 @@ struct llama_grammar_parser {
105
  void print(FILE * file);
106
  };
107
 
 
 
 
 
 
108
  struct llama_grammar {
109
  // note: allow null vocab for testing (not great)
110
  const llama_vocab * vocab;
@@ -116,13 +122,16 @@ struct llama_grammar {
116
  llama_partial_utf8 partial_utf8;
117
 
118
  // lazy grammars wait for trigger words or tokens before constraining the sampling.
119
- // we still ahve trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
120
  // (useful e.g. for tool_choice=required)
121
  bool lazy = false;
122
  bool awaiting_trigger = false; // Initialized to true for lazy grammars only
123
  std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
124
  std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
125
- std::vector<std::string> trigger_words;
 
 
 
126
  };
127
 
128
  //
@@ -141,8 +150,8 @@ struct llama_grammar * llama_grammar_init_impl(
141
  const char * grammar_str,
142
  const char * grammar_root,
143
  bool lazy,
144
- const char ** trigger_words,
145
- size_t num_trigger_words,
146
  const llama_token * trigger_tokens,
147
  size_t num_trigger_tokens);
148
 
 
3
  #include "llama.h"
4
 
5
  #include <map>
6
+ #include <regex>
7
  #include <string>
8
  #include <vector>
9
 
 
106
  void print(FILE * file);
107
  };
108
 
109
+ struct llama_grammar_trigger_pattern {
110
+ std::string pattern;
111
+ std::regex regex;
112
+ };
113
+
114
  struct llama_grammar {
115
  // note: allow null vocab for testing (not great)
116
  const llama_vocab * vocab;
 
122
  llama_partial_utf8 partial_utf8;
123
 
124
  // lazy grammars wait for trigger words or tokens before constraining the sampling.
125
+ // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
126
  // (useful e.g. for tool_choice=required)
127
  bool lazy = false;
128
  bool awaiting_trigger = false; // Initialized to true for lazy grammars only
129
  std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
130
  std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
131
+ std::vector<llama_grammar_trigger_pattern>
132
+ trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated
133
+ // string, and the grammar will be given the string from the first match group onwards.
134
+
135
  };
136
 
137
  //
 
150
  const char * grammar_str,
151
  const char * grammar_root,
152
  bool lazy,
153
+ const char ** trigger_patterns,
154
+ size_t num_trigger_patterns,
155
  const llama_token * trigger_tokens,
156
  size_t num_trigger_tokens);
157
 
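The trigger_patterns semantics described in the header comment above (full match of the whole generated string, constraining starts at the first capture group) can be seen in a small standalone example; the pattern and buffer below are illustrative, not taken from llama.cpp:

#include <cstdio>
#include <regex>
#include <string>

int main() {
    // a lazy-grammar trigger: the pattern must match the ENTIRE buffered output,
    // and group 1 marks where grammar constraining begins
    const std::regex  re("[\\s\\S]*?(<tool_call>[\\s\\S]*)");
    const std::string buffer = "Let me call a tool.\n<tool_call>{\"name\":\"get_time\"}";

    std::smatch match;
    if (std::regex_match(buffer, match, re)) {
        // mirrors llama_grammar_accept_impl: the grammar is fed the buffer from group 1 onwards
        const std::string constrained = buffer.substr(match.position(1));
        std::printf("grammar sees: %s\n", constrained.c_str());
    }
    return 0;
}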
examples/talk-llama/llama-graph.cpp ADDED
@@ -0,0 +1,1706 @@
1
+ #include "llama-graph.h"
2
+
3
+ #include "llama-impl.h"
4
+ #include "llama-batch.h"
5
+ #include "llama-cparams.h"
6
+ #include "llama-kv-cache.h"
7
+
8
+ #include <cassert>
9
+ #include <cmath>
10
+ #include <cstring>
11
+
12
+ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
13
+ // TODO move to hparams if a T5 variant appears that uses a different value
14
+ const int64_t max_distance = 128;
15
+
16
+ if (bidirectional) {
17
+ n_buckets >>= 1;
18
+ }
19
+
20
+ const int64_t max_exact = n_buckets >> 1;
21
+
22
+ int32_t relative_position = x - y;
23
+ int32_t relative_bucket = 0;
24
+
25
+ if (bidirectional) {
26
+ relative_bucket += (relative_position > 0) * n_buckets;
27
+ relative_position = abs(relative_position);
28
+ } else {
29
+ relative_position = -std::min<int32_t>(relative_position, 0);
30
+ }
31
+
32
+ int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
33
+ relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
34
+ relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
35
+
36
+ return relative_bucket;
37
+ }
38
+
39
+ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
40
+ if (ubatch->token) {
41
+ const int64_t n_tokens = ubatch->n_tokens;
42
+
43
+ ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
44
+ }
45
+
46
+ if (ubatch->embd) {
47
+ const int64_t n_embd = embd->ne[0];
48
+ const int64_t n_tokens = ubatch->n_tokens;
49
+
50
+ ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
51
+ }
52
+ }
53
+
54
+ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
55
+ if (ubatch->pos && pos) {
56
+ const int64_t n_tokens = ubatch->n_tokens;
57
+
58
+ ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_token*ggml_element_size(pos));
59
+ }
60
+ }
61
+
62
+ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
63
+ if (ubatch->pos && attn_scale) {
64
+ const int64_t n_tokens = ubatch->n_tokens;
65
+
66
+ std::vector<float> attn_scale_data(n_tokens, 0.0f);
67
+ for (int i = 0; i < n_tokens; ++i) {
68
+ const float pos = ubatch->pos[i];
69
+ attn_scale_data[i] = std::log(
70
+ std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
71
+ ) * f_attn_temp_scale + 1.0;
72
+ }
73
+
74
+ ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(attn_scale));
75
+ }
76
+ }
77
+
78
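In closed form, the attention-temperature input filled in above is, per token at position p:

    attn_scale(p) = 1 + f_attn_temp_scale * ln( floor((p + 1) / n_attn_temp_floor_scale) + 1 )

so the scale grows logarithmically with position, stepping once every n_attn_temp_floor_scale tokens.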
+ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
79
+ if (pos_bucket) {
80
+ const int64_t n_tokens = ubatch->n_tokens;
81
+
82
+ GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
83
+ GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
84
+
85
+ int32_t * data = (int32_t *) pos_bucket->data;
86
+
87
+ for (int h = 0; h < 1; ++h) {
88
+ for (int j = 0; j < n_tokens; ++j) {
89
+ for (int i = 0; i < n_tokens; ++i) {
90
+ data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
91
+ }
92
+ }
93
+ }
94
+ }
95
+ }
96
+
97
+ void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
98
+ if (pos_bucket) {
99
+ const int64_t n_tokens = ubatch->n_tokens;
100
+
101
+ GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
102
+ GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
103
+
104
+ int32_t * data = (int32_t *) pos_bucket->data;
105
+
106
+ const int64_t n_kv = kv_self->n;
107
+
108
+ for (int h = 0; h < 1; ++h) {
109
+ for (int j = 0; j < n_tokens; ++j) {
110
+ for (int i = 0; i < n_kv; ++i) {
111
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
112
+ }
113
+ }
114
+ }
115
+ }
116
+ }
117
+
118
+ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
119
+ if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
120
+ //GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
121
+
122
+ if (!out_ids) {
123
+ LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
124
+ } else {
125
+ const int64_t n_tokens = ubatch->n_tokens;
126
+
127
+ GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
128
+ int32_t * data = (int32_t *) out_ids->data;
129
+
130
+ if (n_outputs == n_tokens) {
131
+ for (int i = 0; i < n_tokens; ++i) {
132
+ data[i] = i;
133
+ }
134
+ } else if (ubatch->output) {
135
+ int32_t n_outputs = 0;
136
+ for (int i = 0; i < n_tokens; ++i) {
137
+ if (ubatch->output[i]) {
138
+ data[n_outputs++] = i;
139
+ }
140
+ }
141
+ // the graph needs to have been passed the correct number of outputs
142
+ GGML_ASSERT(n_outputs == n_outputs);
143
+ } else if (n_outputs == 1) {
144
+ // only keep last output
145
+ data[0] = n_tokens - 1;
146
+ } else {
147
+ GGML_ASSERT(n_outputs == 0);
148
+ }
149
+ }
150
+ }
151
+ }
152
+
153
+ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
154
+ if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
155
+ const int64_t n_tokens = ubatch->n_tokens;
156
+ const int64_t n_seq_tokens = ubatch->n_seq_tokens;
157
+ const int64_t n_seqs = ubatch->n_seqs;
158
+
159
+ GGML_ASSERT(mean);
160
+ GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
161
+
162
+ float * data = (float *) mean->data;
163
+ memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean));
164
+
165
+ std::vector<uint64_t> sum(n_tokens, 0);
166
+
167
+ for (int s = 0; s < n_seqs; ++s) {
168
+ const llama_seq_id seq_id = ubatch->seq_id[s][0];
169
+
170
+ // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
171
+ GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
172
+
173
+ sum[seq_id] += ubatch->n_seq_tokens;
174
+ }
175
+
176
+ std::vector<float> div(n_tokens, 0.0f);
177
+ for (int i = 0; i < n_tokens; ++i) {
178
+ const uint64_t s = sum[i];
179
+ if (s > 0) {
180
+ div[i] = 1.0f/float(s);
181
+ }
182
+ }
183
+
184
+ for (int s = 0; s < n_seqs; ++s) {
185
+ const llama_seq_id seq_id = ubatch->seq_id[s][0];
186
+
187
+ for (int i = 0; i < n_seq_tokens; ++i) {
188
+ data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
189
+ }
190
+ }
191
+ }
192
+ }
193
+
194
+ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
195
+ if (cparams.embeddings && (
196
+ cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
197
+ cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
198
+ const int64_t n_tokens = ubatch->n_tokens;
199
+ const int64_t n_seq_tokens = ubatch->n_seq_tokens;
200
+ const int64_t n_seqs = ubatch->n_seqs;
201
+
202
+ GGML_ASSERT(cls);
203
+ GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
204
+
205
+ uint32_t * data = (uint32_t *) cls->data;
206
+ memset(cls->data, 0, n_tokens * ggml_element_size(cls));
207
+
208
+ for (int s = 0; s < n_seqs; ++s) {
209
+ const llama_seq_id seq_id = ubatch->seq_id[s][0];
210
+
211
+ // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
212
+ GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
213
+
214
+ for (int i = 0; i < n_seq_tokens; ++i) {
215
+ const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
216
+
217
+ if (pos == 0) {
218
+ data[seq_id] = s*n_seq_tokens + i;
219
+ }
220
+ }
221
+ }
222
+ }
223
+
224
+ if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
225
+ const int64_t n_tokens = ubatch->n_tokens;
226
+ const int64_t n_seq_tokens = ubatch->n_seq_tokens;
227
+ const int64_t n_seqs = ubatch->n_seqs;
228
+
229
+ GGML_ASSERT(cls);
230
+ GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
231
+
232
+ uint32_t * data = (uint32_t *) cls->data;
233
+ memset(cls->data, 0, n_tokens * ggml_element_size(cls));
234
+
235
+ std::vector<int> last_pos(n_tokens, -1);
236
+ std::vector<int> last_row(n_tokens, -1);
237
+
238
+ for (int s = 0; s < n_seqs; ++s) {
239
+ const llama_seq_id seq_id = ubatch->seq_id[s][0];
240
+
241
+ // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
242
+ GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
243
+
244
+ for (int i = 0; i < n_seq_tokens; ++i) {
245
+ const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
246
+
247
+ if (pos >= last_pos[seq_id]) {
248
+ last_pos[seq_id] = pos;
249
+ last_row[seq_id] = s*n_seq_tokens + i;
250
+ }
251
+ }
252
+ }
253
+
254
+ for (int i = 0; i < n_tokens; ++i) {
255
+ if (last_row[i] >= 0) {
256
+ data[i] = last_row[i];
257
+ }
258
+ }
259
+ }
260
+ }
261
+
262
+ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
263
+ GGML_UNUSED(ubatch);
264
+
265
+ const int64_t n_kv = kv_self->n;
266
+
267
+ if (s_copy) {
268
+ GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
269
+ int32_t * data = (int32_t *) s_copy->data;
270
+
271
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
272
+ for (uint32_t i = 0; i < n_kv; ++i) {
273
+ const uint32_t cell_id = i + kv_self->head;
274
+
275
+ //////////////////////////////////////////////
276
+ // TODO: this should not mutate the KV cache !
277
+ llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
278
+
279
+ // prevent out-of-bound sources
280
+ if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
281
+ kv_cell.src = cell_id;
282
+ }
283
+
284
+ data[i] = kv_cell.src;
285
+
286
+ // TODO: do not mutate the KV cache
287
+ // ensure copy only happens once
288
+ if (kv_cell.src != (int32_t) cell_id) {
289
+ kv_cell.src = cell_id;
290
+ }
291
+ }
292
+ }
293
+ }
294
+
295
+ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
296
+ GGML_UNUSED(ubatch);
297
+
298
+ const int64_t n_kv = kv_self->n;
299
+
300
+ if (s_mask) {
301
+ GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
302
+ float * data = (float *) s_mask->data;
303
+
304
+ // clear unused states
305
+ for (int i = 0; i < n_kv; ++i) {
306
+ const uint32_t cell_id = i + kv_self->head;
307
+
308
+ //////////////////////////////////////////////
309
+ // TODO: this should not mutate the KV cache !
310
+ llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
311
+
312
+ data[i] = (float) (kv_cell.src >= 0);
313
+
314
+ // only clear once
315
+ if (kv_cell.src < 0) {
316
+ kv_cell.src = cell_id;
317
+ }
318
+ }
319
+ }
320
+ }
321
+
322
+ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
323
+ GGML_UNUSED(ubatch);
324
+
325
+ if (cross_embd && !cross->v_embd.empty()) {
326
+ assert(cross_embd->type == GGML_TYPE_F32);
327
+
328
+ ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd));
329
+ }
330
+ }
331
+
332
+ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
333
+ if (kq_mask) {
334
+ if (cparams.causal_attn) {
335
+ const int64_t n_kv = ubatch->n_tokens;
336
+ const int64_t n_tokens = ubatch->n_tokens;
337
+ const int64_t n_seq_tokens = ubatch->n_seq_tokens;
338
+ const int64_t n_seqs = ubatch->n_seqs;
339
+
340
+ GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
341
+ float * data = (float *) kq_mask->data;
342
+
343
+ for (int h = 0; h < 1; ++h) {
344
+ for (int s1 = 0; s1 < n_seqs; ++s1) {
345
+ const llama_seq_id seq_id = ubatch->seq_id[s1][0];
346
+
347
+ for (int j = 0; j < n_seq_tokens; ++j) {
348
+ const int32_t tj = s1*n_seq_tokens + j;
349
+
350
+ for (int s0 = 0; s0 < n_seqs; ++s0) {
351
+ for (int i = 0; i < n_seq_tokens; ++i) {
352
+ const int32_t ti = s0*n_seq_tokens + i;
353
+ float f = -INFINITY;
354
+
355
+ for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
356
+ if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
357
+ if (hparams.use_alibi) {
358
+ f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
359
+ } else {
360
+ f = 0.0f;
361
+ }
362
+ break;
363
+ }
364
+ }
365
+
366
+ data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
367
+ }
368
+ }
369
+ }
370
+ }
371
+ }
372
+ } else {
373
+ const int64_t n_tokens = ubatch->n_tokens;
374
+ const int64_t n_seq_tokens = ubatch->n_seq_tokens;
375
+ const int64_t n_seqs = ubatch->n_seqs;
376
+ const int64_t n_stride = ubatch->n_tokens;
377
+
378
+ GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
379
+
380
+ float * data = (float *) kq_mask->data;
381
+
382
+ for (int h = 0; h < 1; ++h) {
383
+ for (int s1 = 0; s1 < n_seqs; ++s1) {
384
+ const llama_seq_id seq_id = ubatch->seq_id[s1][0];
385
+
386
+ for (int j = 0; j < n_seq_tokens; ++j) {
387
+ const int32_t tj = s1*n_seq_tokens + j;
388
+
389
+ for (int s0 = 0; s0 < n_seqs; ++s0) {
390
+ for (int i = 0; i < n_seq_tokens; ++i) {
391
+ const int32_t ti = s0*n_seq_tokens + i;
392
+ float f = -INFINITY;
393
+
394
+ for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
395
+ if (ubatch->seq_id[s0][s] == seq_id) {
396
+ if (hparams.use_alibi) {
397
+ f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
398
+ } else {
399
+ f = 0.0f;
400
+ }
401
+ break;
402
+ }
403
+ }
404
+
405
+ data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
406
+ }
407
+ }
408
+
409
+ for (int i = n_tokens; i < n_stride; ++i) {
410
+ data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
411
+ }
412
+ }
413
+ }
414
+ }
415
+ }
416
+ }
417
+ }
418
+
419
+ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
420
+ if (self_kq_mask || self_kq_mask_swa) {
421
+ const int64_t n_kv = kv_self->n;
422
+ const int64_t n_tokens = ubatch->n_tokens;
423
+ const int64_t n_seq_tokens = ubatch->n_seq_tokens;
424
+ const int64_t n_seqs = ubatch->n_seqs;
425
+
426
+ float * data = nullptr;
427
+ float * data_swa = nullptr;
428
+
429
+ if (self_kq_mask) {
430
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
431
+ data = (float *) self_kq_mask->data;
432
+ }
433
+
434
+ if (self_kq_mask_swa) {
435
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
436
+ data_swa = (float *) self_kq_mask_swa->data;
437
+ }
438
+
439
+ // Use only the previous KV cells of the correct sequence for each token of the ubatch.
440
+ // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
441
+ // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
442
+ // Causal mask:
443
+ // xxx-------
444
+ // xxxx------
445
+ // xxxxx-----
446
+ // Non-causal mask:
447
+ // xxxxx-----
448
+ // xxxxx-----
449
+ // xxxxx-----
450
+ // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
451
+ for (int h = 0; h < 1; ++h) {
452
+ for (int s = 0; s < n_seqs; ++s) {
453
+ const llama_seq_id seq_id = ubatch->seq_id[s][0];
454
+
455
+ for (int j = 0; j < n_seq_tokens; ++j) {
456
+ const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
457
+ for (int i = 0; i < n_kv; ++i) {
458
+ float f;
459
+ // mask the token if:
460
+ if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
461
+ || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
462
+ ) {
463
+ f = -INFINITY;
464
+ } else {
465
+ if (hparams.use_alibi) {
466
+ f = -std::abs(kv_self->cells[i].pos - pos);
467
+ } else {
468
+ f = 0.0f;
469
+ }
470
+ }
471
+
472
+ if (data) {
473
+ data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
474
+ }
475
+
476
+ // may need to cut off old tokens for sliding window
477
+ // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
478
+ if (data_swa) {
479
+ if (hparams.n_attn_chunk) {
480
+ llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
481
+ if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
482
+ f = -INFINITY;
483
+ }
484
+ } else {
485
+ if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
486
+ f = -INFINITY;
487
+ }
488
+ }
489
+ data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
490
+ }
491
+ }
492
+ }
493
+ }
494
+
495
+ // mask padded tokens
496
+ if (data) {
497
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
498
+ for (int j = 0; j < n_kv; ++j) {
499
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
500
+ }
501
+ }
502
+ }
503
+
504
+ // mask padded tokens
505
+ if (data_swa) {
506
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
507
+ for (int j = 0; j < n_kv; ++j) {
508
+ data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
509
+ }
510
+ }
511
+ }
512
+ }
513
+ }
514
+ }
515
+
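The masking rule above is easier to see outside of ggml. Below is a minimal standalone sketch of the same idea, assuming plain C++ containers and a flat row-major [n_tokens x n_kv] mask; the function and parameter names are illustrative, not part of llama.cpp's API.

#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

// For query token j and KV cell i the entry is:
//   -INF if the cell belongs to a different sequence, or (causal) lies in the future,
//   0.0f otherwise, or -|pos_i - pos_j| when an ALiBi-style distance is folded into the mask.
std::vector<float> build_kq_mask(
        const std::vector<int32_t> & cell_seq,  // sequence id of each KV cell
        const std::vector<int32_t> & cell_pos,  // position of each KV cell
        const std::vector<int32_t> & tok_seq,   // sequence id of each batch token
        const std::vector<int32_t> & tok_pos,   // position of each batch token
        bool causal, bool use_alibi) {
    const size_t n_kv     = cell_seq.size();
    const size_t n_tokens = tok_seq.size();
    const float  neg_inf  = -std::numeric_limits<float>::infinity();

    std::vector<float> mask(n_tokens*n_kv, neg_inf);
    for (size_t j = 0; j < n_tokens; ++j) {
        for (size_t i = 0; i < n_kv; ++i) {
            if (cell_seq[i] != tok_seq[j])          continue; // other sequence
            if (causal && cell_pos[i] > tok_pos[j]) continue; // future token
            mask[j*n_kv + i] = use_alibi ? -std::abs((float)(cell_pos[i] - tok_pos[j])) : 0.0f;
        }
    }
    return mask;
}

The SWA/chunked variant above additionally re-masks cells that fall outside the sliding window or the current attention chunk.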
516
+ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
517
+ if (cross_kq_mask) {
518
+ const int64_t n_enc = cross_kq_mask->ne[0];
519
+ const int64_t n_tokens = ubatch->n_tokens;
520
+
521
+ GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
522
+ GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
523
+
524
+ float * data = (float *) cross_kq_mask->data;
525
+
526
+ for (int h = 0; h < 1; ++h) {
527
+ for (int j = 0; j < n_tokens; ++j) {
528
+ for (int i = 0; i < n_enc; ++i) {
529
+ float f = -INFINITY;
530
+ for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
531
+ const llama_seq_id seq_id = ubatch->seq_id[j][s];
532
+ if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
533
+ f = 0.0f;
534
+ }
535
+ }
536
+ data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
537
+ }
538
+ }
539
+
540
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
541
+ for (int j = 0; j < n_enc; ++j) {
542
+ data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
543
+ }
544
+ }
545
+ }
546
+ }
547
+ }
548
+
549
+ //
550
+ // llm_graph_context
551
+ //
552
+
553
+ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
554
+ arch (params.arch),
555
+ hparams (params.hparams),
556
+ cparams (params.cparams),
557
+ ubatch (params.ubatch),
558
+ n_embd (hparams.n_embd),
559
+ n_layer (hparams.n_layer),
560
+ n_rot (hparams.n_rot),
561
+ n_ctx (cparams.n_ctx),
562
+ n_ctx_per_seq (cparams.n_ctx / cparams.n_seq_max),
563
+ n_head (hparams.n_head()),
564
+ n_head_kv (hparams.n_head_kv()),
565
+ n_embd_head_k (hparams.n_embd_head_k),
566
+ n_embd_k_gqa (hparams.n_embd_k_gqa()),
567
+ n_embd_head_v (hparams.n_embd_head_v),
568
+ n_embd_v_gqa (hparams.n_embd_v_gqa()),
569
+ n_expert (hparams.n_expert),
570
+ n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
571
+ freq_base (cparams.rope_freq_base),
572
+ freq_scale (cparams.rope_freq_scale),
573
+ ext_factor (cparams.yarn_ext_factor),
574
+ attn_factor (cparams.yarn_attn_factor),
575
+ beta_fast (cparams.yarn_beta_fast),
576
+ beta_slow (cparams.yarn_beta_slow),
577
+ norm_eps (hparams.f_norm_eps),
578
+ norm_rms_eps (hparams.f_norm_rms_eps),
579
+ n_tokens (ubatch.n_tokens),
580
+ n_outputs (params.n_outputs),
581
+ n_ctx_orig (cparams.n_ctx_orig_yarn),
582
+ pooling_type (cparams.pooling_type),
583
+ rope_type (hparams.rope_type),
584
+ ctx0 (params.ctx),
585
+ sched (params.sched),
586
+ backend_cpu (params.backend_cpu),
587
+ cvec (params.cvec),
588
+ loras (params.loras),
589
+ memory (params.memory),
590
+ cross (params.cross),
591
+ cb_func (params.cb),
592
+ res (std::make_unique<llm_graph_result>()) {
593
+ }
594
+
595
+ int64_t llm_graph_context::n_pos_per_token() const {
596
+ return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
597
+ }
598
+
599
+ void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
600
+ if (cb_func) {
601
+ cb_func(ubatch, cur, name, il);
602
+ }
603
+ }
604
+
605
+ ggml_tensor * llm_graph_context::build_cvec(
606
+ ggml_tensor * cur,
607
+ int il) const {
608
+ return cvec->apply_to(ctx0, cur, il);
609
+ }
610
+
611
+ ggml_tensor * llm_graph_context::build_lora_mm(
612
+ ggml_tensor * w,
613
+ ggml_tensor * cur) const {
614
+ ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
615
+
616
+ for (const auto & lora : *loras) {
617
+ llama_adapter_lora_weight * lw = lora.first->get_weight(w);
618
+ if (lw == nullptr) {
619
+ continue;
620
+ }
621
+
622
+ const float adapter_scale = lora.second;
623
+ const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
624
+
625
+ ggml_tensor * ab_cur = ggml_mul_mat(
626
+ ctx0, lw->b,
627
+ ggml_mul_mat(ctx0, lw->a, cur)
628
+ );
629
+
630
+ ab_cur = ggml_scale(ctx0, ab_cur, scale);
631
+ res = ggml_add(ctx0, res, ab_cur);
632
+ }
633
+
634
+ return res;
635
+ }
636
+
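As a rough reference for what build_lora_mm expresses in ggml ops, here is the same computation, y = W*x + scale * B*(A*x), written with plain dense loops. The scale handling mirrors the alpha/rank logic spelled out in build_lora_mm_id below; all names here are illustrative, not llama.cpp API.

#include <vector>

using mat = std::vector<std::vector<float>>; // [rows][cols]

static std::vector<float> matvec(const mat & M, const std::vector<float> & x) {
    std::vector<float> y(M.size(), 0.0f);
    for (size_t r = 0; r < M.size(); ++r)
        for (size_t c = 0; c < x.size(); ++c)
            y[r] += M[r][c]*x[c];
    return y;
}

std::vector<float> lora_mm(const mat & W, const mat & A, const mat & B,
                           const std::vector<float> & x,
                           float alpha, float adapter_scale) {
    const float rank  = (float) A.size();                                   // LoRA rank (rows of A)
    const float scale = alpha != 0.0f ? adapter_scale*alpha/rank : adapter_scale;

    std::vector<float> y  = matvec(W, x);   // base projection
    std::vector<float> ax = matvec(A, x);   // down-projection to rank dims
    std::vector<float> bx = matvec(B, ax);  // up-projection back to output dims
    for (size_t i = 0; i < y.size(); ++i) y[i] += scale*bx[i];
    return y;
}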
637
+ ggml_tensor * llm_graph_context::build_lora_mm_id(
638
+ ggml_tensor * w, // ggml_tensor * as
639
+ ggml_tensor * cur, // ggml_tensor * b
640
+ ggml_tensor * ids) const {
641
+ ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
642
+ for (const auto & lora : *loras) {
643
+ llama_adapter_lora_weight * lw = lora.first->get_weight(w);
644
+ if (lw == nullptr) {
645
+ continue;
646
+ }
647
+
648
+ const float alpha = lora.first->alpha;
649
+ const float rank = (float) lw->b->ne[0];
650
+ const float scale = alpha ? lora.second * alpha / rank : lora.second;
651
+
652
+ ggml_tensor * ab_cur = ggml_mul_mat_id(
653
+ ctx0, lw->b,
654
+ ggml_mul_mat_id(ctx0, lw->a, cur, ids),
655
+ ids
656
+ );
657
+
658
+ ab_cur = ggml_scale(ctx0, ab_cur, scale);
659
+ res = ggml_add(ctx0, res, ab_cur);
660
+ }
661
+
662
+ return res;
663
+ }
664
+
665
+ ggml_tensor * llm_graph_context::build_norm(
666
+ ggml_tensor * cur,
667
+ ggml_tensor * mw,
668
+ ggml_tensor * mb,
669
+ llm_norm_type type,
670
+ int il) const {
671
+ switch (type) {
672
+ case LLM_NORM: cur = ggml_norm (ctx0, cur, hparams.f_norm_eps); break;
673
+ case LLM_NORM_RMS: cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); break;
674
+ case LLM_NORM_GROUP:
675
+ {
676
+ cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], 1, cur->ne[1]);
677
+ cur = ggml_group_norm(ctx0, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
678
+ cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[2]);
679
+ } break;
680
+ }
681
+
682
+ if (mw || mb) {
683
+ cb(cur, "norm", il);
684
+ }
685
+
686
+ if (mw) {
687
+ cur = ggml_mul(ctx0, cur, mw);
688
+ if (mb) {
689
+ cb(cur, "norm_w", il);
690
+ }
691
+ }
692
+
693
+ if (mb) {
694
+ cur = ggml_add(ctx0, cur, mb);
695
+ }
696
+
697
+ return cur;
698
+ }
699
+
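For reference, the LLM_NORM_RMS branch above corresponds to the following plain-loop RMSNorm sketch, with the optional mw weight applied and eps added to the mean square before the square root:

#include <cmath>
#include <vector>

void rms_norm(std::vector<float> & x, const std::vector<float> & weight, float eps) {
    double ss = 0.0;
    for (float v : x) ss += (double) v*v;
    const float inv = 1.0f/std::sqrt((float)(ss/x.size()) + eps); // 1/sqrt(mean(x^2) + eps)
    for (size_t i = 0; i < x.size(); ++i) x[i] = x[i]*inv*weight[i]; // the optional mw multiply
}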
700
+ ggml_tensor * llm_graph_context::build_ffn(
701
+ ggml_tensor * cur,
702
+ ggml_tensor * up,
703
+ ggml_tensor * up_b,
704
+ ggml_tensor * up_s,
705
+ ggml_tensor * gate,
706
+ ggml_tensor * gate_b,
707
+ ggml_tensor * gate_s,
708
+ ggml_tensor * down,
709
+ ggml_tensor * down_b,
710
+ ggml_tensor * down_s,
711
+ ggml_tensor * act_scales,
712
+ llm_ffn_op_type type_op,
713
+ llm_ffn_gate_type type_gate,
714
+ int il) const {
715
+ ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur;
716
+ cb(tmp, "ffn_up", il);
717
+
718
+ if (up_b) {
719
+ tmp = ggml_add(ctx0, tmp, up_b);
720
+ cb(tmp, "ffn_up_b", il);
721
+ }
722
+
723
+ if (up_s) {
724
+ tmp = ggml_mul(ctx0, tmp, up_s);
725
+ cb(tmp, "ffn_up_s", il);
726
+ }
727
+
728
+ if (gate) {
729
+ switch (type_gate) {
730
+ case LLM_FFN_SEQ:
731
+ {
732
+ cur = build_lora_mm(gate, tmp);
733
+ cb(cur, "ffn_gate", il);
734
+ } break;
735
+ case LLM_FFN_PAR:
736
+ {
737
+ cur = build_lora_mm(gate, cur);
738
+ cb(cur, "ffn_gate", il);
739
+ } break;
740
+ }
741
+
742
+ if (gate_b) {
743
+ cur = ggml_add(ctx0, cur, gate_b);
744
+ cb(cur, "ffn_gate_b", il);
745
+ }
746
+
747
+ if (gate_s) {
748
+ cur = ggml_mul(ctx0, cur, gate_s);
749
+ cb(cur, "ffn_gate_s", il);
750
+ }
751
+
752
+ } else {
753
+ cur = tmp;
754
+ }
755
+
756
+ switch (type_op) {
757
+ case LLM_FFN_SILU:
758
+ {
759
+ cur = ggml_silu(ctx0, cur);
760
+ cb(cur, "ffn_silu", il);
761
+ } break;
762
+ case LLM_FFN_GELU:
763
+ {
764
+ cur = ggml_gelu(ctx0, cur);
765
+ cb(cur, "ffn_gelu", il);
766
+ if (act_scales != NULL) {
767
+ cur = ggml_div(ctx0, cur, act_scales);
768
+ cb(cur, "ffn_act", il);
769
+ }
770
+ } break;
771
+ case LLM_FFN_RELU:
772
+ {
773
+ cur = ggml_relu(ctx0, cur);
774
+ cb(cur, "ffn_relu", il);
775
+ } break;
776
+ case LLM_FFN_RELU_SQR:
777
+ {
778
+ cur = ggml_relu(ctx0, cur);
779
+ cb(cur, "ffn_relu", il);
780
+
781
+ cur = ggml_sqr(ctx0, cur);
782
+ cb(cur, "ffn_sqr(relu)", il);
783
+ } break;
784
+ case LLM_FFN_SWIGLU:
785
+ {
786
+ // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
787
+ int64_t split_point = cur->ne[0] / 2;
788
+ ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
789
+ ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
790
+
791
+ x0 = ggml_silu(ctx0, x0);
792
+ cb(cur, "ffn_silu", il);
793
+
794
+ cur = ggml_mul(ctx0, x0, x1);
795
+ cb(cur, "ffn_mul", il);
796
+ } break;
797
+ }
798
+
799
+ if (type_gate == LLM_FFN_PAR) {
800
+ cur = ggml_mul(ctx0, cur, tmp);
801
+ cb(cur, "ffn_gate_par", il);
802
+ }
803
+
804
+ if (down) {
805
+ cur = build_lora_mm(down, cur);
806
+ }
807
+
808
+ if (down_b) {
809
+ cb(cur, "ffn_down", il);
810
+ }
811
+
812
+ if (down_b) {
813
+ cur = ggml_add(ctx0, cur, down_b);
814
+ }
815
+
816
+ if (down_s) {
817
+ cur = ggml_mul(ctx0, cur, down_s);
818
+ cb(cur, "ffn_down_s", il);
819
+ }
820
+
821
+ return cur;
822
+ }
823
+
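The LLM_FFN_SWIGLU branch above splits the up-projection output in half and gates one half with SiLU of the other. A scalar sketch of that step (names are illustrative):

#include <cmath>
#include <vector>

std::vector<float> swiglu(const std::vector<float> & up) { // up = [x0 | x1], even length
    const size_t half = up.size()/2;
    std::vector<float> out(half);
    for (size_t i = 0; i < half; ++i) {
        const float x0   = up[i];
        const float silu = x0/(1.0f + std::exp(-x0)); // SiLU(x) = x * sigmoid(x)
        out[i] = silu*up[half + i];
    }
    return out;
}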
824
+ ggml_tensor * llm_graph_context::build_moe_ffn(
825
+ ggml_tensor * cur,
826
+ ggml_tensor * gate_inp,
827
+ ggml_tensor * up_exps,
828
+ ggml_tensor * gate_exps,
829
+ ggml_tensor * down_exps,
830
+ ggml_tensor * exp_probs_b,
831
+ int64_t n_expert,
832
+ int64_t n_expert_used,
833
+ llm_ffn_op_type type_op,
834
+ bool norm_w,
835
+ bool scale_w,
836
+ float w_scale,
837
+ llama_expert_gating_func_type gating_op,
838
+ int il) const {
839
+ const int64_t n_embd = cur->ne[0];
840
+ const int64_t n_tokens = cur->ne[1];
841
+ const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
842
+
843
+ ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
844
+ cb(logits, "ffn_moe_logits", il);
845
+
846
+ ggml_tensor * probs = nullptr;
847
+ switch (gating_op) {
848
+ case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
849
+ {
850
+ probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
851
+ } break;
852
+ case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
853
+ {
854
+ probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
855
+ } break;
856
+ default:
857
+ GGML_ABORT("fatal error");
858
+ }
859
+ cb(probs, "ffn_moe_probs", il);
860
+
861
+ // add experts selection bias - introduced in DeepSeek V3
862
+ // leave probs unbiased as it's later used to get expert weights
863
+ ggml_tensor * selection_probs = probs;
864
+ if (exp_probs_b != nullptr) {
865
+ selection_probs = ggml_add(ctx0, probs, exp_probs_b);
866
+ cb(selection_probs, "ffn_moe_probs_biased", il);
867
+ }
868
+
869
+ // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
870
+ // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
871
+ if (arch == LLM_ARCH_LLAMA4) {
872
+ selection_probs = logits;
873
+ }
874
+
875
+ // select experts
876
+ ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
877
+ cb(selected_experts->src[0], "ffn_moe_argsort", il);
878
+ cb(selected_experts, "ffn_moe_topk", il);
879
+
880
+ ggml_tensor * weights = ggml_get_rows(ctx0,
881
+ ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
882
+ cb(weights, "ffn_moe_weights", il);
883
+
884
+ if (norm_w) {
885
+ weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
886
+
887
+ ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
888
+ cb(weights_sum, "ffn_moe_weights_sum", il);
889
+
890
+ weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
891
+ cb(weights, "ffn_moe_weights_norm", il);
892
+
893
+ weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
894
+ }
895
+ if (scale_w) {
896
+ weights = ggml_scale(ctx0, weights, w_scale);
897
+ cb(weights, "ffn_moe_weights_scaled", il);
898
+ }
899
+
900
+ cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
901
+
902
+ if (weight_before_ffn) {
903
+ // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
904
+ ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
905
+ repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
906
+ cur = ggml_mul(ctx0, repeated, weights);
907
+ cb(cur, "ffn_moe_weighted", il);
908
+ }
909
+
910
+ ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
911
+ cb(up, "ffn_moe_up", il);
912
+
913
+ ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
914
+ cb(gate, "ffn_moe_gate", il);
915
+
916
+ switch (type_op) {
917
+ case LLM_FFN_SILU:
918
+ {
919
+ gate = ggml_silu(ctx0, gate);
920
+ cb(gate, "ffn_moe_silu", il);
921
+ } break;
922
+ case LLM_FFN_GELU:
923
+ {
924
+ gate = ggml_gelu(ctx0, gate);
925
+ cb(gate, "ffn_moe_gelu", il);
926
+ } break;
927
+ default:
928
+ GGML_ABORT("fatal error");
929
+ }
930
+
931
+ ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
932
+ cb(par, "ffn_moe_gate_par", il);
933
+
934
+ ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
935
+ cb(experts, "ffn_moe_down", il);
936
+
937
+ if (!weight_before_ffn) {
938
+ experts = ggml_mul(ctx0, experts, weights);
939
+ cb(cur, "ffn_moe_weighted", il);
940
+ }
941
+
942
+ // aggregate experts
943
+ ggml_tensor * moe_out = nullptr;
944
+ for (int i = 0; i < n_expert_used; ++i) {
945
+ ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
946
+ experts->nb[2], i*experts->nb[1]);
947
+
948
+ if (i == 0) {
949
+ moe_out = cur_expert;
950
+ } else {
951
+ moe_out = ggml_add(ctx0, moe_out, cur_expert);
952
+ }
953
+ }
954
+
955
+ if (n_expert_used == 1) {
956
+ // avoid returning a non-contiguous tensor
957
+ moe_out = ggml_cont(ctx0, moe_out);
958
+ }
959
+
960
+ cb(moe_out, "ffn_moe_out", il);
961
+
962
+ return moe_out;
963
+ }
964
+
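Per token, the routing performed by build_moe_ffn amounts to gating, top-k selection, and optional renormalization of the selected weights. A standalone sketch of just that step, assuming softmax gating and plain C++ containers (names and layout are illustrative):

#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

struct routing { std::vector<int> experts; std::vector<float> weights; };

routing route_token(const std::vector<float> & logits, int n_used, bool norm_w) {
    // softmax gating (the LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX case)
    const float mx = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (size_t e = 0; e < logits.size(); ++e) { probs[e] = std::exp(logits[e] - mx); sum += probs[e]; }
    for (float & p : probs) p /= sum;

    // top-k experts by gating probability
    std::vector<int> idx(logits.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + n_used, idx.end(),
                      [&](int a, int b) { return probs[a] > probs[b]; });
    idx.resize(n_used);

    routing r;
    r.experts = idx;
    float wsum = 0.0f;
    for (int e : idx) { r.weights.push_back(probs[e]); wsum += probs[e]; }
    if (norm_w) for (float & w : r.weights) w /= wsum; // ffn_moe_weights_norm
    return r;
}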
965
+ // input embeddings with optional lora
966
+ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
967
+ const int64_t n_embd = hparams.n_embd;
968
+
969
+ auto inp = std::make_unique<llm_graph_input_embd>();
970
+
971
+ ggml_tensor * cur = nullptr;
972
+
973
+ if (ubatch.token) {
974
+ inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
975
+ //cb(inp->tokens, "inp_tokens", -1);
976
+ ggml_set_input(inp->tokens);
977
+
978
+ cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
979
+
980
+ // apply lora for embedding tokens if needed
981
+ for (const auto & lora : *loras) {
982
+ llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
983
+ if (lw == nullptr) {
984
+ continue;
985
+ }
986
+
987
+ const float adapter_scale = lora.second;
988
+ const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
989
+
990
+ ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
991
+ ctx0, lw->b, // non-transposed lora_b
992
+ ggml_get_rows(ctx0, lw->a, inp->tokens)
993
+ ), scale);
994
+
995
+ cur = ggml_add(ctx0, cur, inpL_delta);
996
+ }
997
+ } else {
998
+ inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
999
+ ggml_set_input(inp->embd);
1000
+
1001
+ cur = inp->embd;
1002
+ }
1003
+
1004
+ // For Granite architecture
1005
+ if (hparams.f_embedding_scale != 0.0f) {
1006
+ cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
1007
+ }
1008
+
1009
+ cb(cur, "inp_embd", -1);
1010
+
1011
+ res->add_input(std::move(inp));
1012
+
1013
+ return cur;
1014
+ }
1015
+
1016
+ ggml_tensor * llm_graph_context::build_inp_pos() const {
1017
+ auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_token());
1018
+
1019
+ auto & cur = inp->pos;
1020
+
1021
+ cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token());
1022
+ ggml_set_input(cur);
1023
+
1024
+ res->add_input(std::move(inp));
1025
+
1026
+ return cur;
1027
+ }
1028
+
1029
+ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
1030
+ auto inp = std::make_unique<llm_graph_input_attn_temp>(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
1031
+
1032
+ auto & cur = inp->attn_scale;
1033
+
1034
+ cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token());
1035
+ ggml_set_input(cur);
1036
+
1037
+ res->add_input(std::move(inp));
1038
+
1039
+ return cur;
1040
+ }
1041
+
1042
+ ggml_tensor * llm_graph_context::build_inp_out_ids() const {
1043
+ auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
1044
+
1045
+ auto & cur = inp->out_ids;
1046
+
1047
+ cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
1048
+ ggml_set_input(cur);
1049
+
1050
+ res->add_input(std::move(inp));
1051
+
1052
+ return cur;
1053
+ }
1054
+
1055
+ ggml_tensor * llm_graph_context::build_inp_mean() const {
1056
+ auto inp = std::make_unique<llm_graph_input_mean>(cparams);
1057
+
1058
+ auto & cur = inp->mean;
1059
+
1060
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
1061
+ ggml_set_input(cur);
1062
+
1063
+ res->add_input(std::move(inp));
1064
+
1065
+ return cur;
1066
+ }
1067
+
1068
+ ggml_tensor * llm_graph_context::build_inp_cls() const {
1069
+ auto inp = std::make_unique<llm_graph_input_cls>(cparams);
1070
+
1071
+ auto & cur = inp->cls;
1072
+
1073
+ cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
1074
+ ggml_set_input(cur);
1075
+
1076
+ res->add_input(std::move(inp));
1077
+
1078
+ return cur;
1079
+ }
1080
+
1081
+ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
1082
+ const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1083
+
1084
+ auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
1085
+
1086
+ const auto n_kv = kv_self->n;
1087
+
1088
+ auto & cur = inp->s_copy;
1089
+
1090
+ cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
1091
+ ggml_set_input(cur);
1092
+
1093
+ res->add_input(std::move(inp));
1094
+
1095
+ return cur;
1096
+ }
1097
+
1098
+ ggml_tensor * llm_graph_context::build_inp_s_mask() const {
1099
+ const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1100
+
1101
+ auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
1102
+
1103
+ const auto n_kv = kv_self->n;
1104
+
1105
+ auto & cur = inp->s_mask;
1106
+
1107
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
1108
+ ggml_set_input(cur);
1109
+
1110
+ res->add_input(std::move(inp));
1111
+
1112
+ return cur;
1113
+ }
1114
+
1115
+ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
1116
+ auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
1117
+
1118
+ auto & cur = inp->cross_embd;
1119
+
1120
+ // if we have the output embeddings from the encoder, use them directly
1121
+ // TODO: needs more work to be correct, for now just use the tensor shape
1122
+ //if (cross->t_embd) {
1123
+ // cur = ggml_view_tensor(ctx0, cross->t_embd);
1124
+
1125
+ // return cur;
1126
+ //}
1127
+
1128
+ const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd;
1129
+ const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
1130
+
1131
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
1132
+ ggml_set_input(cur);
1133
+
1134
+ res->add_input(std::move(inp));
1135
+
1136
+ return cur;
1137
+ }
1138
+
1139
+ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
1140
+ auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);
1141
+
1142
+ auto & cur = inp->pos_bucket;
1143
+
1144
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
1145
+ ggml_set_input(cur);
1146
+
1147
+ res->add_input(std::move(inp));
1148
+
1149
+ return cur;
1150
+ }
1151
+
1152
+ ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
1153
+ const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1154
+
1155
+ auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
1156
+
1157
+ const auto n_kv = kv_self->n;
1158
+
1159
+ auto & cur = inp->pos_bucket;
1160
+
1161
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
1162
+ ggml_set_input(cur);
1163
+
1164
+ res->add_input(std::move(inp));
1165
+
1166
+ return cur;
1167
+ }
1168
+
1169
+ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const {
1170
+ ggml_tensor * pos_bucket_1d = ggml_reshape_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1]);
1171
+ cb(pos_bucket_1d, "pos_bucket_1d", -1);
1172
+
1173
+ ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
1174
+
1175
+ pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]);
1176
+ pos_bias = ggml_permute (ctx0, pos_bias, 2, 0, 1, 3);
1177
+ pos_bias = ggml_cont (ctx0, pos_bias);
1178
+
1179
+ cb(pos_bias, "pos_bias", -1);
1180
+
1181
+ return pos_bias;
1182
+ }
1183
+
1184
+ ggml_tensor * llm_graph_context::build_attn_mha(
1185
+ ggml_cgraph * gf,
1186
+ ggml_tensor * q,
1187
+ ggml_tensor * k,
1188
+ ggml_tensor * v,
1189
+ ggml_tensor * kq_b,
1190
+ ggml_tensor * kq_mask,
1191
+ ggml_tensor * v_mla,
1192
+ bool v_trans,
1193
+ float kq_scale) const {
1194
+ //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1195
+ //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1196
+
1197
+ //const int64_t n_head = hparams.n_head(il);
1198
+ //const int64_t n_head_kv = hparams.n_head_kv(il);
1199
+
1200
+ //const auto & n_embd_head_k = hparams.n_embd_head_k;
1201
+ //const auto & n_embd_head_v = hparams.n_embd_head_v;
1202
+
1203
+ const auto n_tokens = q->ne[1];
1204
+ const auto n_head = q->ne[2];
1205
+ const auto n_kv = k->ne[1];
1206
+
1207
+ ggml_tensor * cur;
1208
+
1209
+ // TODO: replace hardcoded padding with ggml-provided padding
1210
+ if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
1211
+ GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
1212
+
1213
+ if (v_trans) {
1214
+ v = ggml_transpose(ctx0, v);
1215
+ }
1216
+
1217
+ // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
1218
+ if (k->type == GGML_TYPE_F32) {
1219
+ k = ggml_cast(ctx0, k, GGML_TYPE_F16);
1220
+ }
1221
+
1222
+ if (v->type == GGML_TYPE_F32) {
1223
+ v = ggml_cast(ctx0, v, GGML_TYPE_F16);
1224
+ }
1225
+
1226
+ cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
1227
+ hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
1228
+
1229
+ ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
1230
+
1231
+ if (v_mla) {
1232
+ cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
1233
+ cur = ggml_mul_mat(ctx0, v_mla, cur);
1234
+ }
1235
+
1236
+ cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
1237
+ } else {
1238
+ ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
1239
+
1240
+ // note: this op tends to require high floating point range
1241
+ // while for some models F16 is enough, for others it is not, so we default to F32 here
1242
+ ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
1243
+
1244
+ if (arch == LLM_ARCH_GROK) {
1245
+ // need to do the following:
1246
+ // multiply by attn_output_multiplier of 0.08838834764831845
1247
+ // and then :
1248
+ // kq = 30 * tanh(kq / 30)
1249
+ // before the softmax below
1250
+
1251
+ kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
1252
+ kq = ggml_scale(ctx0, kq, 30);
1253
+ }
1254
+
1255
+ if (hparams.attn_soft_cap) {
1256
+ kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
1257
+ kq = ggml_tanh (ctx0, kq);
1258
+ kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
1259
+ }
1260
+
1261
+ if (kq_b) {
1262
+ kq = ggml_add(ctx0, kq, kq_b);
1263
+ }
1264
+
1265
+ kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
1266
+
1267
+ if (!v_trans) {
1268
+ // note: avoid this branch
1269
+ v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
1270
+ }
1271
+
1272
+ ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
1273
+
1274
+ // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
1275
+ if (v_mla) {
1276
+ kqv = ggml_mul_mat(ctx0, v_mla, kqv);
1277
+ }
1278
+
1279
+ cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
1280
+
1281
+ cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
1282
+
1283
+ if (!cparams.offload_kqv) {
1284
+ // all nodes between the KV store and the attention output are run on the CPU
1285
+ ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
1286
+ }
1287
+ }
1288
+
1289
+ ggml_build_forward_expand(gf, cur);
1290
+
1291
+ return cur;
1292
+ }
1293
+
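In the non-flash-attention path above, each row of KQ is optionally soft-capped, then scaled, masked, and softmaxed. A single-row sketch of that sequence of operations, ignoring ALiBi slopes; names are illustrative and padded rows whose mask is entirely -INF are not handled:

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

void softmax_row(std::vector<float> & kq, const std::vector<float> & mask,
                 float kq_scale, float softcap /* 0.0f = disabled */) {
    float mx = -std::numeric_limits<float>::infinity();
    for (size_t i = 0; i < kq.size(); ++i) {
        float v = kq[i];
        if (softcap > 0.0f) v = softcap*std::tanh(v/softcap); // logit soft-capping
        v = v*kq_scale + mask[i];                             // scale, then add -INF/0 mask
        kq[i] = v;
        mx = std::max(mx, v);
    }
    float sum = 0.0f;
    for (float & v : kq) { v = std::exp(v - mx); sum += v; }
    for (float & v : kq) v /= sum;
}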
1294
+ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
1295
+ auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
1296
+
1297
+ // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
1298
+ inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1299
+ //cb(inp_kq_mask, "KQ_mask", -1);
1300
+ ggml_set_input(inp->kq_mask);
1301
+
1302
+ inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
1303
+
1304
+ return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
1305
+ }
1306
+
1307
+ ggml_tensor * llm_graph_context::build_attn(
1308
+ llm_graph_input_attn_no_cache * inp,
1309
+ ggml_cgraph * gf,
1310
+ ggml_tensor * wo,
1311
+ ggml_tensor * wo_b,
1312
+ ggml_tensor * q_cur,
1313
+ ggml_tensor * k_cur,
1314
+ ggml_tensor * v_cur,
1315
+ ggml_tensor * kq_b,
1316
+ ggml_tensor * v_mla,
1317
+ float kq_scale,
1318
+ int il) const {
1319
+ GGML_UNUSED(n_tokens);
1320
+
1321
+ // these nodes are added to the graph together so that they are not reordered
1322
+ // by doing so, the number of splits in the graph is reduced
1323
+ ggml_build_forward_expand(gf, q_cur);
1324
+ ggml_build_forward_expand(gf, k_cur);
1325
+ ggml_build_forward_expand(gf, v_cur);
1326
+
1327
+ const auto & kq_mask = inp->get_kq_mask();
1328
+
1329
+ ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
1330
+ //cb(q, "q", il);
1331
+
1332
+ ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
1333
+ //cb(k, "k", il);
1334
+
1335
+ ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
1336
+ //cb(k, "v", il);
1337
+
1338
+ ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
1339
+
1340
+ cb(cur, "kqv_out", il);
1341
+
1342
+ if (wo) {
1343
+ cur = build_lora_mm(wo, cur);
1344
+ }
1345
+
1346
+ if (wo_b) {
1347
+ //cb(cur, "kqv_wo", il);
1348
+ }
1349
+
1350
+ if (wo_b) {
1351
+ cur = ggml_add(ctx0, cur, wo_b);
1352
+ }
1353
+
1354
+ return cur;
1355
+ }
1356
+
1357
+ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
1358
+ const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1359
+
1360
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
1361
+
1362
+ const auto n_kv = kv_self->n;
1363
+
1364
+ inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1365
+ //cb(inp->self_kq_mask, "KQ_mask", -1);
1366
+ ggml_set_input(inp->self_kq_mask);
1367
+
1368
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1369
+
1370
+ if (hparams.n_swa_pattern > 1) {
1371
+ GGML_ASSERT(hparams.n_swa > 0);
1372
+
1373
+ inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1374
+ //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
1375
+ ggml_set_input(inp->self_kq_mask_swa);
1376
+
1377
+ inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1378
+ }
1379
+
1380
+ return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
1381
+ }
1382
+
1383
+ ggml_tensor * llm_graph_context::build_attn(
1384
+ llm_graph_input_attn_kv_unified * inp,
1385
+ ggml_cgraph * gf,
1386
+ ggml_tensor * wo,
1387
+ ggml_tensor * wo_b,
1388
+ ggml_tensor * q_cur,
1389
+ ggml_tensor * k_cur,
1390
+ ggml_tensor * v_cur,
1391
+ ggml_tensor * kq_b,
1392
+ ggml_tensor * v_mla,
1393
+ float kq_scale,
1394
+ int il) const {
1395
+ // these nodes are added to the graph together so that they are not reordered
1396
+ // by doing so, the number of splits in the graph is reduced
1397
+ ggml_build_forward_expand(gf, q_cur);
1398
+ ggml_build_forward_expand(gf, k_cur);
1399
+ ggml_build_forward_expand(gf, v_cur);
1400
+
1401
+ const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1402
+ const auto & n_ctx = cparams.n_ctx;
1403
+
1404
+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1405
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1406
+
1407
+ const auto n_tokens = q_cur->ne[2];
1408
+
1409
+ const bool v_trans = !cparams.flash_attn;
1410
+
1411
+ // store to KV cache
1412
+ {
1413
+ GGML_ASSERT(!kv_self->recurrent);
1414
+
1415
+ const auto kv_head = kv_self->head;
1416
+
1417
+ GGML_ASSERT(kv_self->size == n_ctx);
1418
+
1419
+ ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
1420
+ //cb(k_cache_view, "k_cache_view", il);
1421
+
1422
+ // note: storing RoPE-ed version of K in the KV cache
1423
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
1424
+
1425
+ v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
1426
+
1427
+ ggml_tensor * v_cache_view = nullptr;
1428
+
1429
+ if (!v_trans) {
1430
+ v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
1431
+ } else {
1432
+ // note: the V cache is transposed when not using flash attention
1433
+ v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
1434
+ ( n_ctx)*ggml_element_size(kv_self->v_l[il]),
1435
+ (kv_head)*ggml_element_size(kv_self->v_l[il]));
1436
+
1437
+ v_cur = ggml_transpose(ctx0, v_cur);
1438
+ }
1439
+ //cb(v_cache_view, "v_cache_view", il);
1440
+
1441
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
1442
+ }
1443
+
1444
+ const bool is_swa = hparams.is_swa(il);
1445
+
1446
+ const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
1447
+
1448
+ const auto n_kv = kv_self->n;
1449
+
1450
+ const int64_t n_head_kv = hparams.n_head_kv(il);
1451
+
1452
+ const auto & n_embd_head_k = hparams.n_embd_head_k;
1453
+ const auto & n_embd_head_v = hparams.n_embd_head_v;
1454
+
1455
+ ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
1456
+ //cb(q, "q", il);
1457
+
1458
+ ggml_tensor * k =
1459
+ ggml_view_3d(ctx0, kv_self->k_l[il],
1460
+ n_embd_head_k, n_kv, n_head_kv,
1461
+ ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
1462
+ ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
1463
+ 0);
1464
+ //cb(k, "k", il);
1465
+
1466
+ ggml_tensor * v = !v_trans ?
1467
+ ggml_view_3d(ctx0, kv_self->v_l[il],
1468
+ n_embd_head_v, n_kv, n_head_kv,
1469
+ ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
1470
+ ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
1471
+ 0) :
1472
+ ggml_view_3d(ctx0, kv_self->v_l[il],
1473
+ n_kv, n_embd_head_v, n_head_kv,
1474
+ ggml_element_size(kv_self->v_l[il])*n_ctx,
1475
+ ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
1476
+ 0);
1477
+
1478
+ ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
1479
+ cb(cur, "kqv_out", il);
1480
+
1481
+ if (wo) {
1482
+ cur = build_lora_mm(wo, cur);
1483
+ }
1484
+
1485
+ if (wo_b) {
1486
+ //cb(cur, "kqv_wo", il);
1487
+ }
1488
+
1489
+ if (wo_b) {
1490
+ cur = ggml_add(ctx0, cur, wo_b);
1491
+ }
1492
+
1493
+ return cur;
1494
+ }
1495
+
1496
+ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
1497
+ auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);
1498
+
1499
+ const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
1500
+
1501
+ inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1502
+ ggml_set_input(inp->cross_kq_mask);
1503
+
1504
+ inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
1505
+
1506
+ return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
1507
+ }
1508
+
1509
+ ggml_tensor * llm_graph_context::build_attn(
1510
+ llm_graph_input_attn_cross * inp,
1511
+ ggml_cgraph * gf,
1512
+ ggml_tensor * wo,
1513
+ ggml_tensor * wo_b,
1514
+ ggml_tensor * q_cur,
1515
+ ggml_tensor * k_cur,
1516
+ ggml_tensor * v_cur,
1517
+ ggml_tensor * kq_b,
1518
+ ggml_tensor * v_mla,
1519
+ float kq_scale,
1520
+ int il) const {
1521
+ // these nodes are added to the graph together so that they are not reordered
1522
+ // by doing so, the number of splits in the graph is reduced
1523
+ ggml_build_forward_expand(gf, q_cur);
1524
+ ggml_build_forward_expand(gf, k_cur);
1525
+ ggml_build_forward_expand(gf, v_cur);
1526
+
1527
+ const auto & kq_mask = inp->get_kq_mask_cross();
1528
+
1529
+ ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
1530
+ //cb(q, "q", il);
1531
+
1532
+ ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
1533
+ //cb(k, "k", il);
1534
+
1535
+ ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
1536
+ //cb(k, "v", il);
1537
+
1538
+ ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
1539
+
1540
+ cb(cur, "kqv_out", il);
1541
+
1542
+ if (wo) {
1543
+ cur = build_lora_mm(wo, cur);
1544
+ }
1545
+
1546
+ if (wo_b) {
1547
+ //cb(cur, "kqv_wo", il);
1548
+ }
1549
+
1550
+ if (wo_b) {
1551
+ cur = ggml_add(ctx0, cur, wo_b);
1552
+ }
1553
+
1554
+ return cur;
1555
+ }
1556
+
1557
+ ggml_tensor * llm_graph_context::build_copy_mask_state(
1558
+ ggml_cgraph * gf,
1559
+ ggml_tensor * s,
1560
+ ggml_tensor * state_copy,
1561
+ ggml_tensor * state_mask,
1562
+ int32_t n_state,
1563
+ int32_t n_seqs) const {
1564
+ const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1565
+
1566
+ const auto n_kv = kv_self->n;
1567
+ const auto kv_head = kv_self->head;
1568
+
1569
+ ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
1570
+
1571
+ // copy states
1572
+ // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1573
+ // this shrinks the tensor's ne[1] to n_kv
1574
+ states = ggml_get_rows(ctx0, states, state_copy);
1575
+
1576
+ // clear states of sequences which are starting at the beginning of this batch
1577
+ // FIXME: zero-out NANs?
1578
+ states = ggml_mul(ctx0, states, state_mask);
1579
+
1580
+ // copy states which won't be changed further (between n_seqs and n_kv)
1581
+ ggml_build_forward_expand(gf,
1582
+ ggml_cpy(ctx0,
1583
+ ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)),
1584
+ ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
1585
+
1586
+ // the part of the states that will be used and modified
1587
+ return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
1588
+ }
1589
+
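build_copy_mask_state expresses a gather plus a zeroing mask in ggml ops; per destination slot it conceptually reduces to the following sketch with plain containers (names are illustrative):

#include <vector>

// states: [n_cells][n_state]; state_copy/state_mask have one entry per destination slot
std::vector<std::vector<float>> copy_mask_states(
        const std::vector<std::vector<float>> & states,
        const std::vector<int>   & state_copy,   // source cell index per destination
        const std::vector<float> & state_mask) { // 0.0f for sequences starting in this batch
    std::vector<std::vector<float>> out(state_copy.size());
    for (size_t i = 0; i < state_copy.size(); ++i) {
        out[i] = states[state_copy[i]];              // ggml_get_rows
        for (float & v : out[i]) v *= state_mask[i]; // ggml_mul with the mask
    }
    return out;
}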
1590
+ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1591
+ ggml_cgraph * gf,
1592
+ ggml_tensor * state_copy,
1593
+ ggml_tensor * state_mask,
1594
+ const llama_ubatch & ubatch,
1595
+ int il) const {
1596
+ const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1597
+
1598
+ const auto token_shift_count = hparams.token_shift_count;
1599
+
1600
+ const int64_t n_seqs = ubatch.n_seqs;
1601
+
1602
+ ggml_tensor * token_shift_all = kv_self->k_l[il];
1603
+
1604
+ ggml_tensor * token_shift = build_copy_mask_state(
1605
+ gf, token_shift_all, state_copy, state_mask,
1606
+ hparams.n_embd_k_s(), n_seqs);
1607
+
1608
+ token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
1609
+
1610
+ return token_shift;
1611
+ }
1612
+
1613
+ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
1614
+ ggml_tensor * token_shift,
1615
+ const llama_ubatch & ubatch,
1616
+ int il) const {
1617
+ const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1618
+
1619
+ const auto token_shift_count = hparams.token_shift_count;
1620
+ const auto n_embd = hparams.n_embd;
1621
+
1622
+ const int64_t n_seqs = ubatch.n_seqs;
1623
+
1624
+ const auto kv_head = kv_self->head;
1625
+
1626
+ return ggml_cpy(
1627
+ ctx0,
1628
+ ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
1629
+ ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
1630
+ );
1631
+ }
1632
+
1633
+ void llm_graph_context::build_pooling(
1634
+ ggml_cgraph * gf,
1635
+ ggml_tensor * cls,
1636
+ ggml_tensor * cls_b,
1637
+ ggml_tensor * cls_out,
1638
+ ggml_tensor * cls_out_b) const {
1639
+ if (!cparams.embeddings) {
1640
+ return;
1641
+ }
1642
+
1643
+ ggml_tensor * inp = res->t_embd;
1644
+
1645
+ //// find result_norm tensor for input
1646
+ //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
1647
+ // inp = ggml_graph_node(gf, i);
1648
+ // if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
1649
+ // break;
1650
+ // }
1651
+
1652
+ // inp = nullptr;
1653
+ //}
1654
+
1655
+ GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
1656
+
1657
+ ggml_tensor * cur;
1658
+
1659
+ switch (pooling_type) {
1660
+ case LLAMA_POOLING_TYPE_NONE:
1661
+ {
1662
+ cur = inp;
1663
+ } break;
1664
+ case LLAMA_POOLING_TYPE_MEAN:
1665
+ {
1666
+ ggml_tensor * inp_mean = build_inp_mean();
1667
+ cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
1668
+ } break;
1669
+ case LLAMA_POOLING_TYPE_CLS:
1670
+ case LLAMA_POOLING_TYPE_LAST:
1671
+ {
1672
+ ggml_tensor * inp_cls = build_inp_cls();
1673
+ cur = ggml_get_rows(ctx0, inp, inp_cls);
1674
+ } break;
1675
+ case LLAMA_POOLING_TYPE_RANK:
1676
+ {
1677
+ ggml_tensor * inp_cls = build_inp_cls();
1678
+ inp = ggml_get_rows(ctx0, inp, inp_cls);
1679
+
1680
+ // classification head
1681
+ // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
1682
+ GGML_ASSERT(cls != nullptr);
1683
+ GGML_ASSERT(cls_b != nullptr);
1684
+
1685
+ cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
1686
+ cur = ggml_tanh(ctx0, cur);
1687
+
1688
+ // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
1689
+ // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
1690
+ if (cls_out) {
1691
+ GGML_ASSERT(cls_out_b != nullptr);
1692
+
1693
+ cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
1694
+ }
1695
+ } break;
1696
+ default:
1697
+ {
1698
+ GGML_ABORT("unknown pooling type");
1699
+ }
1700
+ }
1701
+
1702
+ cb(cur, "result_embd_pooled", -1);
1703
+ res->t_embd_pooled = cur;
1704
+
1705
+ ggml_build_forward_expand(gf, cur);
1706
+ }
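The LLAMA_POOLING_TYPE_MEAN case above realizes mean pooling as a matmul against the inp_mean matrix; for a single sequence it reduces to this plain-C++ sketch (illustrative only):

#include <vector>

std::vector<float> mean_pool(const std::vector<std::vector<float>> & token_embd) {
    if (token_embd.empty()) return {};
    std::vector<float> out(token_embd[0].size(), 0.0f);
    for (const auto & t : token_embd)
        for (size_t i = 0; i < out.size(); ++i) out[i] += t[i];
    for (float & v : out) v /= (float) token_embd.size();
    return out;
}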
examples/talk-llama/llama-graph.h ADDED
@@ -0,0 +1,596 @@
1
+ #pragma once
2
+
3
+ #include "llama-arch.h"
4
+ #include "llama-hparams.h"
5
+ #include "llama-adapter.h"
6
+
7
+ #include <cstdint>
8
+ #include <vector>
9
+ #include <memory>
10
+ #include <set>
11
+ #include <functional>
12
+
13
+ struct ggml_cgraph;
14
+ struct ggml_context;
15
+ struct ggml_tensor;
16
+
17
+ struct llama_ubatch;
18
+ struct llama_cparams;
19
+
20
+ class llama_memory_i;
21
+ class llama_kv_cache_unified;
22
+
23
+ // certain models (typically multi-modal) can produce different types of graphs
24
+ enum llm_graph_type {
25
+ LLM_GRAPH_TYPE_DEFAULT,
26
+ LLM_GRAPH_TYPE_ENCODER,
27
+ LLM_GRAPH_TYPE_DECODER,
28
+ };
29
+
30
+ enum llm_ffn_op_type {
31
+ LLM_FFN_SILU,
32
+ LLM_FFN_GELU,
33
+ LLM_FFN_RELU,
34
+ LLM_FFN_RELU_SQR,
35
+ LLM_FFN_SWIGLU,
36
+ };
37
+
38
+ enum llm_ffn_gate_type {
39
+ LLM_FFN_SEQ,
40
+ LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
41
+ };
42
+
43
+ enum llm_norm_type {
44
+ LLM_NORM,
45
+ LLM_NORM_RMS,
46
+ LLM_NORM_GROUP,
47
+ };
48
+
49
+ // TODO: tmp - need something better to pass the data from the encoder to the decoder
50
+ struct llama_cross {
51
+ // the output embeddings from the encoder as a ggml tensor
52
+ // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
53
+ // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
54
+ //ggml_tensor * t_embd = nullptr;
55
+
56
+ int64_t n_embd = 0;
57
+ int64_t n_enc = 0;
58
+
59
+ // embeddings data copied to host memory (tmp)
60
+ std::vector<float> v_embd;
61
+
62
+ // needed to construct the cross-attention mask in the decoder
63
+ std::vector<std::set<llama_seq_id>> seq_ids_enc;
64
+ };
65
+
66
+ //
67
+ // llm_graph_input
68
+ //
69
+
70
+ class llm_graph_input_i {
71
+ public:
72
+ virtual ~llm_graph_input_i() = default;
73
+
74
+ virtual void set_input(const llama_ubatch * ubatch) = 0;
75
+ };
76
+
77
+ using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
78
+
79
+
80
+ class llm_graph_input_embd : public llm_graph_input_i {
81
+ public:
82
+ llm_graph_input_embd() = default;
83
+ virtual ~llm_graph_input_embd() = default;
84
+
85
+ void set_input(const llama_ubatch * ubatch) override;
86
+
87
+ ggml_tensor * tokens = nullptr; // I32 [n_batch]
88
+ ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
89
+ };
90
+
91
+ class llm_graph_input_pos : public llm_graph_input_i {
92
+ public:
93
+ llm_graph_input_pos(int64_t n_pos_per_token) : n_pos_per_token(n_pos_per_token) {}
94
+ virtual ~llm_graph_input_pos() = default;
95
+
96
+ void set_input(const llama_ubatch * ubatch) override;
97
+
98
+ ggml_tensor * pos = nullptr; // I32 [n_batch]
99
+
100
+ const int64_t n_pos_per_token = 1;
101
+ };
102
+
103
+ // temperature tuning, used by llama4
104
+ class llm_graph_input_attn_temp : public llm_graph_input_i {
105
+ public:
106
+ llm_graph_input_attn_temp(int64_t n_pos_per_token, uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
107
+ : n_pos_per_token(n_pos_per_token), n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
108
+ virtual ~llm_graph_input_attn_temp() = default;
109
+
110
+ void set_input(const llama_ubatch * ubatch) override;
111
+
112
+ ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
113
+
114
+ const int64_t n_pos_per_token = 1;
115
+
116
+ const uint32_t n_attn_temp_floor_scale;
117
+ const float f_attn_temp_scale;
118
+ };
119
+
120
+ class llm_graph_input_pos_bucket : public llm_graph_input_i {
121
+ public:
122
+ llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
123
+ virtual ~llm_graph_input_pos_bucket() = default;
124
+
125
+ void set_input(const llama_ubatch * ubatch) override;
126
+
127
+ ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
128
+
129
+ const llama_hparams & hparams;
130
+ };
131
+
132
+ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
133
+ public:
134
+ llm_graph_input_pos_bucket_kv(
135
+ const llama_hparams & hparams,
136
+ const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
137
+ virtual ~llm_graph_input_pos_bucket_kv() = default;
138
+
139
+ void set_input(const llama_ubatch * ubatch) override;
140
+
141
+ ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
142
+
143
+ const llama_hparams & hparams;
144
+ const llama_kv_cache_unified * kv_self;
145
+ };
146
+
147
+ class llm_graph_input_out_ids : public llm_graph_input_i {
148
+ public:
149
+ llm_graph_input_out_ids(
150
+ const llama_hparams & hparams,
151
+ const llama_cparams & cparams,
152
+ int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
153
+ virtual ~llm_graph_input_out_ids() = default;
154
+
155
+ void set_input(const llama_ubatch * ubatch) override;
156
+
157
+ ggml_tensor * out_ids; // I32 [n_outputs]
158
+
159
+ const llama_hparams & hparams;
160
+ const llama_cparams & cparams;
161
+
162
+ const int32_t n_outputs;
163
+ };
164
+
165
+ class llm_graph_input_mean : public llm_graph_input_i {
166
+ public:
167
+ llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
168
+ virtual ~llm_graph_input_mean() = default;
169
+
170
+ void set_input(const llama_ubatch * ubatch) override;
171
+
172
+ ggml_tensor * mean; // F32 [n_batch, n_batch]
173
+
174
+ const llama_cparams & cparams;
175
+ };
176
+
177
+ class llm_graph_input_cls : public llm_graph_input_i {
178
+ public:
179
+ llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
180
+ virtual ~llm_graph_input_cls() = default;
181
+
182
+ void set_input(const llama_ubatch * ubatch) override;
183
+
184
+ ggml_tensor * cls; // I32 [n_batch]
185
+
186
+ const llama_cparams & cparams;
187
+ };
188
+
189
+ class llm_graph_input_s_copy : public llm_graph_input_i {
190
+ public:
191
+ llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
192
+ virtual ~llm_graph_input_s_copy() = default;
193
+
194
+ void set_input(const llama_ubatch * ubatch) override;
195
+
196
+ ggml_tensor * s_copy; // I32 [kv_size]
197
+
198
+ const llama_kv_cache_unified * kv_self;
199
+ };
200
+
201
+ class llm_graph_input_s_mask : public llm_graph_input_i {
202
+ public:
203
+ llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
204
+ virtual ~llm_graph_input_s_mask() = default;
205
+
206
+ void set_input(const llama_ubatch * ubatch) override;
207
+
208
+ ggml_tensor * s_mask; // F32 [1, n_kv]
209
+
210
+ const llama_kv_cache_unified * kv_self;
211
+ };
212
+
213
+ class llm_graph_input_cross_embd : public llm_graph_input_i {
214
+ public:
215
+ llm_graph_input_cross_embd(
216
+ const llama_cross * cross) : cross(cross) {}
217
+ virtual ~llm_graph_input_cross_embd() = default;
218
+
219
+ void set_input(const llama_ubatch * ubatch) override;
220
+
221
+ ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
222
+
223
+ const llama_cross * cross;
224
+ };
225
+
226
+ class llm_graph_input_attn_no_cache : public llm_graph_input_i {
227
+ public:
228
+ llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
229
+ hparams(hparams),
230
+ cparams(cparams) {
231
+ }
232
+ ~llm_graph_input_attn_no_cache() = default;
233
+
234
+ void set_input(const llama_ubatch * ubatch) override;
235
+
236
+ ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
237
+
238
+ ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
239
+ ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
240
+
241
+ const llama_hparams & hparams;
242
+ const llama_cparams & cparams;
243
+ };
244
+
245
+ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
246
+ public:
247
+ llm_graph_input_attn_kv_unified(
248
+ const llama_hparams & hparams,
249
+ const llama_cparams & cparams,
250
+ const llama_kv_cache_unified * kv_self) :
251
+ hparams(hparams),
252
+ cparams(cparams),
253
+ kv_self(kv_self) {
254
+ }
255
+ ~llm_graph_input_attn_kv_unified() = default;
256
+
257
+ void set_input(const llama_ubatch * ubatch) override;
258
+
259
+ ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
260
+ ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
261
+
262
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
263
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
264
+ ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
265
+ ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
266
+
267
+ const llama_hparams & hparams;
268
+ const llama_cparams & cparams;
269
+
270
+ const llama_kv_cache_unified * kv_self;
271
+ };
272
+
273
+ class llm_graph_input_attn_cross : public llm_graph_input_i {
274
+ public:
275
+ llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
276
+ ~llm_graph_input_attn_cross() = default;
277
+
278
+ void set_input(const llama_ubatch * ubatch) override;
279
+
280
+ ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
281
+
282
+ ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
283
+ ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
284
+
285
+ const llama_cross * cross = nullptr;
286
+ };
287
+
288
+ //
289
+ // llm_graph_result
290
+ //
291
+
292
+ // these objects deliver the result from the graph build process back to the llama_context
293
+ // note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
294
+ // specific data, by calling the set_inputs() method
295
+ // along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
296
+ // these are used by the llama_context to extract the relevant data, based on the compute parameters
297
+
298
+ class llm_graph_result_i {
299
+ public:
300
+ virtual ~llm_graph_result_i() = default;
301
+
302
+ virtual ggml_tensor * get_logits() = 0;
303
+ virtual ggml_tensor * get_embd() = 0;
304
+ virtual ggml_tensor * get_embd_pooled() = 0;
305
+
306
+ virtual void set_inputs(const llama_ubatch * ubatch) = 0;
307
+ };
308
+
309
+ using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
310
+
311
+
312
+ class llm_graph_result : public llm_graph_result_i {
313
+ public:
314
+ virtual ~llm_graph_result() = default;
315
+
316
+ ggml_tensor * get_logits() override { return t_logits; }
317
+ ggml_tensor * get_embd() override { return t_embd; }
318
+ ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
319
+
320
+ void set_inputs(const llama_ubatch * ubatch) override {
321
+ for (auto & input : inputs) {
322
+ input->set_input(ubatch);
323
+ }
324
+ }
325
+
326
+ llm_graph_input_i * add_input(llm_graph_input_ptr input) {
327
+ inputs.emplace_back(std::move(input));
328
+ return inputs.back().get();
329
+ }
330
+
331
+ // important graph nodes
332
+ ggml_tensor * t_logits = nullptr;
333
+ ggml_tensor * t_embd = nullptr;
334
+ ggml_tensor * t_embd_pooled = nullptr;
335
+
336
+ std::vector<llm_graph_input_ptr> inputs;
337
+ };
338
+
339
+ //
340
+ // llm_graph_context
341
+ //
342
+
343
+ // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
344
+ using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
345
+
346
+ struct llm_graph_params {
347
+ ggml_context * ctx;
348
+
349
+ const llm_arch arch;
350
+
351
+ const llama_hparams & hparams;
352
+ const llama_cparams & cparams;
353
+ const llama_ubatch & ubatch;
354
+
355
+ ggml_backend_sched * sched;
356
+ ggml_backend * backend_cpu;
357
+
358
+ const llama_adapter_cvec * cvec;
359
+ const llama_adapter_loras * loras;
360
+ const llama_memory_i * memory;
361
+ const llama_cross * cross;
362
+
363
+ int32_t n_outputs;
364
+
365
+ const llm_graph_cb & cb;
366
+ };
367
+
368
+ struct llm_graph_context {
369
+ const llm_arch arch;
370
+
371
+ const llama_hparams & hparams;
372
+ const llama_cparams & cparams;
373
+ const llama_ubatch & ubatch;
374
+
375
+ const int64_t n_embd;
376
+ const int64_t n_layer;
377
+ const int64_t n_rot;
378
+ const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
379
+ const int64_t n_ctx_per_seq;
380
+ const int64_t n_head;
381
+ const int64_t n_head_kv;
382
+ const int64_t n_embd_head_k;
383
+ const int64_t n_embd_k_gqa;
384
+ const int64_t n_embd_head_v;
385
+ const int64_t n_embd_v_gqa;
386
+ const int64_t n_expert;
387
+ const int64_t n_expert_used;
388
+
389
+ const float freq_base;
390
+ const float freq_scale;
391
+ const float ext_factor;
392
+ const float attn_factor;
393
+ const float beta_fast;
394
+ const float beta_slow;
395
+ const float norm_eps;
396
+ const float norm_rms_eps;
397
+
398
+ const int32_t n_tokens;
399
+ const int32_t n_outputs;
400
+ const int32_t n_ctx_orig; // yarn
401
+
402
+ const enum llama_pooling_type pooling_type;
403
+ const enum llama_rope_type rope_type;
404
+
405
+ ggml_context * ctx0 = nullptr;
406
+
407
+ ggml_backend_sched * sched;
408
+
409
+ ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
410
+
411
+ const llama_adapter_cvec * cvec;
412
+ const llama_adapter_loras * loras;
413
+ const llama_memory_i * memory;
414
+ const llama_cross * cross;
415
+
416
+ const llm_graph_cb & cb_func;
417
+
418
+ std::unique_ptr<llm_graph_result> res;
419
+
420
+ llm_graph_context(const llm_graph_params & params);
421
+
422
+ int64_t n_pos_per_token() const;
423
+
424
+ void cb(ggml_tensor * cur, const char * name, int il) const;
425
+
426
+ //
427
+ // common
428
+ //
429
+
430
+ ggml_tensor * build_cvec(
431
+ ggml_tensor * cur,
432
+ int il) const;
433
+
434
+ // do mat_mul, while optionally apply lora
435
+ ggml_tensor * build_lora_mm(
436
+ ggml_tensor * w,
437
+ ggml_tensor * cur) const;
438
+
439
+ // do mat_mul_id, while optionally apply lora
440
+ ggml_tensor * build_lora_mm_id(
441
+ ggml_tensor * w, // ggml_tensor * as
442
+ ggml_tensor * cur, // ggml_tensor * b
443
+ ggml_tensor * ids) const;
444
+
445
+ ggml_tensor * build_norm(
446
+ ggml_tensor * cur,
447
+ ggml_tensor * mw,
448
+ ggml_tensor * mb,
449
+ llm_norm_type type,
450
+ int il) const;
451
+
452
+ ggml_tensor * build_ffn(
453
+ ggml_tensor * cur,
454
+ ggml_tensor * up,
455
+ ggml_tensor * up_b,
456
+ ggml_tensor * up_s,
457
+ ggml_tensor * gate,
458
+ ggml_tensor * gate_b,
459
+ ggml_tensor * gate_s,
460
+ ggml_tensor * down,
461
+ ggml_tensor * down_b,
462
+ ggml_tensor * down_s,
463
+ ggml_tensor * act_scales,
464
+ llm_ffn_op_type type_op,
465
+ llm_ffn_gate_type type_gate,
466
+ int il) const;
467
+
468
+ ggml_tensor * build_moe_ffn(
469
+ ggml_tensor * cur,
470
+ ggml_tensor * gate_inp,
471
+ ggml_tensor * up_exps,
472
+ ggml_tensor * gate_exps,
473
+ ggml_tensor * down_exps,
474
+ ggml_tensor * exp_probs_b,
475
+ int64_t n_expert,
476
+ int64_t n_expert_used,
477
+ llm_ffn_op_type type_op,
478
+ bool norm_w,
479
+ bool scale_w,
480
+ float w_scale,
481
+ llama_expert_gating_func_type gating_op,
482
+ int il) const;
483
+
484
+ //
485
+ // inputs
486
+ //
487
+
488
+ ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
489
+ ggml_tensor * build_inp_pos() const;
490
+ ggml_tensor * build_inp_attn_scale() const;
491
+ ggml_tensor * build_inp_out_ids() const;
492
+ ggml_tensor * build_inp_mean() const;
493
+ ggml_tensor * build_inp_cls() const;
494
+ ggml_tensor * build_inp_s_copy() const;
495
+ ggml_tensor * build_inp_s_mask() const;
496
+
497
+ ggml_tensor * build_inp_cross_embd() const;
498
+ ggml_tensor * build_inp_pos_bucket_enc() const;
499
+ ggml_tensor * build_inp_pos_bucket_dec() const;
500
+ ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
501
+
502
+ //
503
+ // attention
504
+ //
505
+
506
+ ggml_tensor * build_attn_mha(
507
+ ggml_cgraph * gf,
508
+ ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
509
+ ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
510
+ ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
511
+ ggml_tensor * kq_b,
512
+ ggml_tensor * kq_mask,
513
+ ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
514
+ bool v_trans,
515
+ float kq_scale) const;
516
+
517
+ llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
518
+
519
+ ggml_tensor * build_attn(
520
+ llm_graph_input_attn_no_cache * inp,
521
+ ggml_cgraph * gf,
522
+ ggml_tensor * wo,
523
+ ggml_tensor * wo_b,
524
+ ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
525
+ ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
526
+ ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
527
+ ggml_tensor * kq_b,
528
+ ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
529
+ float kq_scale,
530
+ int il) const;
531
+
532
+ llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
533
+
534
+ ggml_tensor * build_attn(
535
+ llm_graph_input_attn_kv_unified * inp,
536
+ ggml_cgraph * gf,
537
+ ggml_tensor * wo,
538
+ ggml_tensor * wo_b,
539
+ ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
540
+ ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
541
+ ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
542
+ ggml_tensor * kq_b,
543
+ ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
544
+ float kq_scale,
545
+ int il) const;
546
+
547
+ llm_graph_input_attn_cross * build_attn_inp_cross() const;
548
+
549
+ ggml_tensor * build_attn(
550
+ llm_graph_input_attn_cross * inp,
551
+ ggml_cgraph * gf,
552
+ ggml_tensor * wo,
553
+ ggml_tensor * wo_b,
554
+ ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
555
+ ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
556
+ ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
557
+ ggml_tensor * kq_b,
558
+ ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
559
+ float kq_scale,
560
+ int il) const;
561
+
562
+ //
563
+ // recurrent
564
+ //
565
+
566
+ ggml_tensor * build_copy_mask_state(
567
+ ggml_cgraph * gf,
568
+ ggml_tensor * s,
569
+ ggml_tensor * state_copy,
570
+ ggml_tensor * state_mask,
571
+ int32_t n_state,
572
+ int32_t n_seqs) const;
573
+
574
+ ggml_tensor * build_rwkv_token_shift_load(
575
+ ggml_cgraph * gf,
576
+ ggml_tensor * state_copy,
577
+ ggml_tensor * state_mask,
578
+ const llama_ubatch & ubatch,
579
+ int il) const;
580
+
581
+ ggml_tensor * build_rwkv_token_shift_store(
582
+ ggml_tensor * token_shift,
583
+ const llama_ubatch & ubatch,
584
+ int il) const;
585
+
586
+ //
587
+ // pooling
588
+ //
589
+
590
+ void build_pooling(
591
+ ggml_cgraph * gf,
592
+ ggml_tensor * cls,
593
+ ggml_tensor * cls_b,
594
+ ggml_tensor * cls_out,
595
+ ggml_tensor * cls_out_b) const;
596
+ };
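Note: the helpers declared above are meant to be composed by the per-model graph builders. A minimal sketch of the intended call pattern (illustrative only, not part of the commit; `inpL`, `layer` and `il` are assumed names, and the LLM_NORM_RMS enum value is assumed from upstream llama.cpp):

// illustrative only - the real builders live in llama-model.cpp
ggml_tensor * cur = build_norm(inpL, layer.attn_norm, /*mb=*/nullptr, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);                        // named callback, e.g. for debugging hooks
ggml_tensor * q = build_lora_mm(layer.wq, cur);  // mat_mul with any active LoRA adapters applied
cb(q, "q", il);
cur = build_cvec(cur, il);                       // apply an optional control vector for this layer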
examples/talk-llama/llama-hparams.cpp CHANGED
@@ -69,3 +69,11 @@ uint32_t llama_hparams::n_embd_v_s() const {
69
  // corresponds to Mamba's ssm_states size
70
  return ssm_d_state * ssm_d_inner;
71
  }
 
 
 
 
 
 
 
 
 
69
  // corresponds to Mamba's ssm_states size
70
  return ssm_d_state * ssm_d_inner;
71
  }
72
+
73
+ bool llama_hparams::is_swa(uint32_t il) const {
74
+ if (il < n_layer) {
75
+ return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
76
+ }
77
+
78
+ GGML_ABORT("fatal error");
79
+ }
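A short worked example of the pattern check above (illustrative, assuming n_swa > 0):

//   n_swa_pattern = 1  ->  il % 1 < 0 is never true          ->  every layer uses full attention (default)
//   n_swa_pattern = 4  ->  layers with il % 4 in {0, 1, 2}    ->  sliding-window attention
//                          layers with il % 4 == 3            ->  full attention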
examples/talk-llama/llama-hparams.h CHANGED
@@ -36,12 +36,17 @@ struct llama_hparams {
36
  uint32_t n_layer;
37
  uint32_t n_rot;
38
  uint32_t n_swa = 0; // sliding window attention (SWA)
 
39
  uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
40
  uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
41
  uint32_t n_expert = 0;
42
  uint32_t n_expert_used = 0;
43
  uint32_t n_rel_attn_bkts = 0;
44
 
 
 
 
 
45
  // for WavTokenizer
46
  struct llama_hparams_posnet posnet;
47
  struct llama_hparams_convnext convnext;
@@ -75,10 +80,16 @@ struct llama_hparams {
75
  uint32_t time_decay_extra_dim = 0;
76
  uint32_t wkv_head_size = 0;
77
  uint32_t token_shift_count = 2;
 
 
 
 
78
 
79
  float rope_attn_factor = 1.0f;
80
  float rope_freq_base_train;
 
81
  float rope_freq_scale_train;
 
82
  uint32_t n_ctx_orig_yarn;
83
  float rope_yarn_log_mul;
84
 
@@ -105,6 +116,14 @@ struct llama_hparams {
105
  bool use_alibi = false;
106
  bool attn_soft_cap = false;
107
 
 
 
 
 
 
 
 
 
108
  // needed by encoder-decoder models (e.g. T5, FLAN-T5)
109
  // ref: https://github.com/ggerganov/llama.cpp/pull/8141
110
  llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
@@ -133,6 +152,8 @@ struct llama_hparams {
133
 
134
  // dimension of the recurrent state embeddings
135
  uint32_t n_embd_v_s() const;
 
 
136
  };
137
 
138
  static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
 
36
  uint32_t n_layer;
37
  uint32_t n_rot;
38
  uint32_t n_swa = 0; // sliding window attention (SWA)
39
+ uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
40
  uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
41
  uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
42
  uint32_t n_expert = 0;
43
  uint32_t n_expert_used = 0;
44
  uint32_t n_rel_attn_bkts = 0;
45
 
46
+ // note: deepseek2 with MLA converts the attention to MQA with larger heads, then decompresses back to MHA
47
+ uint32_t n_embd_head_k_mla = 0;
48
+ uint32_t n_embd_head_v_mla = 0;
49
+
50
  // for WavTokenizer
51
  struct llama_hparams_posnet posnet;
52
  struct llama_hparams_convnext convnext;
 
80
  uint32_t time_decay_extra_dim = 0;
81
  uint32_t wkv_head_size = 0;
82
  uint32_t token_shift_count = 2;
83
+ uint32_t n_lora_decay = 0;
84
+ uint32_t n_lora_iclr = 0;
85
+ uint32_t n_lora_value_res_mix = 0;
86
+ uint32_t n_lora_gate = 0;
87
 
88
  float rope_attn_factor = 1.0f;
89
  float rope_freq_base_train;
90
+ float rope_freq_base_train_swa;
91
  float rope_freq_scale_train;
92
+ float rope_freq_scale_train_swa;
93
  uint32_t n_ctx_orig_yarn;
94
  float rope_yarn_log_mul;
95
 
 
116
  bool use_alibi = false;
117
  bool attn_soft_cap = false;
118
 
119
+ uint32_t n_moe_layer_step = 0;
120
+ bool use_kq_norm = true;
121
+ uint32_t n_attn_chunk = 0;
122
+ // the values below seem to be fixed for llama4
123
+ uint32_t n_no_rope_layer_step = 4;
124
+ uint32_t n_attn_temp_floor_scale = 8192;
125
+ float f_attn_temp_scale = 0.1;
126
+
127
  // needed by encoder-decoder models (e.g. T5, FLAN-T5)
128
  // ref: https://github.com/ggerganov/llama.cpp/pull/8141
129
  llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
 
152
 
153
  // dimension of the recurrent state embeddings
154
  uint32_t n_embd_v_s() const;
155
+
156
+ bool is_swa(uint32_t il) const;
157
  };
158
 
159
  static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
examples/talk-llama/llama-impl.h CHANGED
@@ -6,13 +6,13 @@
6
  #include <vector>
7
 
8
  #ifdef __GNUC__
9
- #ifdef __MINGW32__
10
- #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
 
 
 
11
  #else
12
- #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
13
- #endif
14
- #else
15
- #define LLAMA_ATTRIBUTE_FORMAT(...)
16
  #endif
17
 
18
  //
 
6
  #include <vector>
7
 
8
  #ifdef __GNUC__
9
+ # if defined(__MINGW32__) && !defined(__clang__)
10
+ # define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
11
+ # else
12
+ # define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
13
+ # endif
14
  #else
15
+ # define LLAMA_ATTRIBUTE_FORMAT(...)
 
 
 
16
  #endif
17
 
18
  //
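A sketch of how the macro above is typically used (illustrative; `format_example` is a hypothetical declaration, not one added by this commit). The attribute lets GCC/Clang type-check printf-style arguments, with gnu_printf selected on MinGW (but not clang) where the default printf semantics differ:

LLAMA_ATTRIBUTE_FORMAT(1, 2)  // argument 1 is the format string, varargs start at argument 2
static std::string format_example(const char * fmt, ...);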
examples/talk-llama/llama-io.cpp ADDED
@@ -0,0 +1,15 @@
 
1
+ #include "llama-io.h"
2
+
3
+ void llama_io_write_i::write_string(const std::string & str) {
4
+ uint32_t str_size = str.size();
5
+
6
+ write(&str_size, sizeof(str_size));
7
+ write(str.data(), str_size);
8
+ }
9
+
10
+ void llama_io_read_i::read_string(std::string & str) {
11
+ uint32_t str_size;
12
+ read_to(&str_size, sizeof(str_size));
13
+
14
+ str.assign((const char *) read(str_size), str_size);
15
+ }
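The helpers above define a simple length-prefixed string layout: a 32-bit size followed by the raw bytes, in host byte order. For example (illustrative):

// write_string("hi") emits 6 bytes on a little-endian host:
//   02 00 00 00 'h' 'i'
// read_string() reverses this: read 4 bytes of length, then that many bytes of payload.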
examples/talk-llama/llama-io.h ADDED
@@ -0,0 +1,35 @@
 
1
+ #pragma once
2
+
3
+ #include <cstddef>
4
+ #include <cstdint>
5
+ #include <string>
6
+
7
+ struct ggml_tensor;
8
+
9
+ class llama_io_write_i {
10
+ public:
11
+ llama_io_write_i() = default;
12
+ virtual ~llama_io_write_i() = default;
13
+
14
+ virtual void write(const void * src, size_t size) = 0;
15
+ virtual void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) = 0;
16
+
17
+ // bytes written so far
18
+ virtual size_t n_bytes() = 0;
19
+
20
+ void write_string(const std::string & str);
21
+ };
22
+
23
+ class llama_io_read_i {
24
+ public:
25
+ llama_io_read_i() = default;
26
+ virtual ~llama_io_read_i() = default;
27
+
28
+ virtual const uint8_t * read(size_t size) = 0;
29
+ virtual void read_to(void * dst, size_t size) = 0;
30
+
31
+ // bytes read so far
32
+ virtual size_t n_bytes() = 0;
33
+
34
+ void read_string(std::string & str);
35
+ };
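A minimal sketch of an implementation of the write interface above (illustrative only; the real implementations live in llama-context.cpp and are not part of this excerpt). It assumes ggml_backend_tensor_get() from ggml-backend.h for copying tensor data to host memory:

#include "llama-io.h"
#include "ggml-backend.h"

#include <vector>

struct llama_io_write_buffer_sketch : llama_io_write_i {
    std::vector<uint8_t> buf;

    void write(const void * src, size_t size) override {
        const uint8_t * p = static_cast<const uint8_t *>(src);
        buf.insert(buf.end(), p, p + size);
    }

    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
        std::vector<uint8_t> tmp(size);
        ggml_backend_tensor_get(tensor, tmp.data(), offset, size); // device -> host copy
        write(tmp.data(), size);
    }

    size_t n_bytes() override { return buf.size(); }
};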
examples/talk-llama/llama-kv-cache.cpp CHANGED
@@ -6,86 +6,90 @@
6
  #include "llama-model.h"
7
 
8
  #include <algorithm>
 
9
  #include <limits>
10
  #include <map>
 
11
 
12
- static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
13
-
14
- uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
15
- // the FA kernels require padding to avoid extra runtime boundary checks
16
- return cparams.flash_attn ? 256u : 32u;
17
  }
18
 
19
- bool llama_kv_cache_init(
20
- struct llama_kv_cache & cache,
21
- const llama_model & model,
22
- const llama_cparams & cparams,
23
- ggml_type type_k,
24
- ggml_type type_v,
25
- uint32_t kv_size,
26
- bool offload) {
27
- const struct llama_hparams & hparams = model.hparams;
28
-
29
  const int32_t n_layer = hparams.n_layer;
30
 
31
- cache.has_shift = false;
32
 
33
- cache.recurrent = llama_model_is_recurrent(&model);
34
- cache.v_trans = !cache.recurrent && !cparams.flash_attn;
35
- cache.can_shift = !cache.recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
36
 
37
  LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
38
- __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, cache.can_shift);
39
 
40
- cache.head = 0;
41
- cache.size = kv_size;
42
- cache.used = 0;
43
 
44
- cache.type_k = type_k;
45
- cache.type_v = type_v;
46
 
47
- cache.cells.clear();
48
- cache.cells.resize(kv_size);
49
 
50
  // create a context for each buffer type
51
  std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
52
  auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
53
  auto it = ctx_map.find(buft);
54
  if (it == ctx_map.end()) {
55
- struct ggml_init_params params = {
56
  /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
57
  /*.mem_buffer =*/ NULL,
58
  /*.no_alloc =*/ true,
59
  };
 
60
  ggml_context * ctx = ggml_init(params);
61
  if (!ctx) {
62
  return nullptr;
63
  }
 
64
  ctx_map[buft] = ctx;
65
- cache.ctxs.emplace_back(ctx);
 
66
  return ctx;
67
  }
 
68
  return it->second;
69
  };
70
 
71
- cache.k_l.reserve(n_layer);
72
- cache.v_l.reserve(n_layer);
73
 
74
  for (int i = 0; i < n_layer; i++) {
75
  const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
76
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
77
 
78
- LLAMA_LOG_DEBUG("%s: layer %d: n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", __func__, i, n_embd_k_gqa, n_embd_v_gqa);
79
 
80
  ggml_backend_buffer_type_t buft;
81
  if (offload) {
82
  auto * dev = model.dev_layer(i);
83
  buft = ggml_backend_dev_buffer_type(dev);
 
 
84
  } else {
85
  buft = ggml_backend_cpu_buffer_type();
86
  }
87
- ggml_context * ctx = ctx_for_buft(buft);
88
 
 
 
 
 
89
  if (!ctx) {
90
  LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
91
  return false;
@@ -95,8 +99,8 @@ bool llama_kv_cache_init(
95
  ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
96
  ggml_format_name(k, "cache_k_l%d", i);
97
  ggml_format_name(v, "cache_v_l%d", i);
98
- cache.k_l.push_back(k);
99
- cache.v_l.push_back(v);
100
  }
101
 
102
  // allocate tensors and initialize the buffers to avoid NaNs in the padding
@@ -111,20 +115,403 @@ bool llama_kv_cache_init(
111
  }
112
  ggml_backend_buffer_clear(buf, 0);
113
  LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
114
- cache.bufs.emplace_back(buf);
 
 
115
  }
116
 
117
  return true;
118
  }
119
 
120
- struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
121
- struct llama_kv_cache & cache,
122
- const struct llama_ubatch & ubatch) {
 
 
123
  const uint32_t n_tokens = ubatch.n_tokens;
124
  const uint32_t n_seqs = ubatch.n_seqs;
125
  const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
126
 
127
- if (cache.recurrent) {
 
 
 
 
 
 
128
  // For recurrent state architectures (like Mamba or RWKV),
129
  // each cache cell can store the state for a whole sequence.
130
  // A slot should always be contiguous.
@@ -132,7 +519,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
132
  // can only process batches with an equal number of new tokens in each sequence
133
  GGML_ASSERT(ubatch.equal_seqs);
134
 
135
- int32_t min = cache.size - 1;
136
  int32_t max = 0;
137
 
138
  // everything should fit if all seq_ids are smaller than the max
@@ -141,16 +528,16 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
141
  for (uint32_t j = 0; j < n_seq_id; ++j) {
142
  const llama_seq_id seq_id = ubatch.seq_id[s][j];
143
 
144
- if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
145
  // too big seq_id
146
  // TODO: would it be possible to resize the cache instead?
147
- LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
148
- return llama_kv_cache_slot_info_failed;
149
  }
150
  if (j > 0) {
151
- llama_kv_cell & seq = cache.cells[seq_id];
152
  if (seq.tail >= 0) {
153
- llama_kv_cell & cell = cache.cells[seq.tail];
154
  // clear cells from seq_ids that become shared
155
  // (should not normally happen, but let's handle it anyway)
156
  cell.seq_id.erase(seq_id);
@@ -158,7 +545,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
158
  if (cell.seq_id.empty()) {
159
  cell.pos = -1;
160
  cell.src = -1;
161
- cache.used -= 1;
162
  }
163
  }
164
  }
@@ -168,9 +555,9 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
168
  #ifndef NDEBUG
169
  {
170
  std::vector<int32_t> tails_verif;
171
- tails_verif.assign(cache.size, -1);
172
- for (uint32_t i = 0; i < cache.size; ++i) {
173
- llama_kv_cell & cell = cache.cells[i];
174
  for (llama_seq_id seq_id : cell.seq_id) {
175
  if (tails_verif[seq_id] != -1) {
176
  LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
@@ -178,20 +565,20 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
178
  tails_verif[seq_id] = i;
179
  }
180
  }
181
- for (uint32_t i = 0; i < cache.size; ++i) {
182
- if (tails_verif[i] != cache.cells[i].tail) {
183
- LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]);
184
  }
185
  }
186
  }
187
  #endif
188
 
189
  // find next empty cell
190
- uint32_t next_empty_cell = cache.head;
191
 
192
- for (uint32_t i = 0; i < cache.size; ++i) {
193
- if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
194
- llama_kv_cell & cell = cache.cells[next_empty_cell];
195
  if (cell.is_empty()) { break; }
196
  next_empty_cell += 1;
197
  }
@@ -199,20 +586,20 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
199
  // find usable cell range
200
  for (uint32_t s = 0; s < n_seqs; ++s) {
201
  const llama_seq_id seq_id = ubatch.seq_id[s][0];
202
- llama_kv_cell & seq_meta = cache.cells[seq_id];
203
  bool has_cell = false;
204
  if (seq_meta.tail >= 0) {
205
- llama_kv_cell & cell = cache.cells[seq_meta.tail];
206
  GGML_ASSERT(cell.has_seq_id(seq_id));
207
  // does this seq_id "own" the cell?
208
  if (cell.seq_id.size() == 1) { has_cell = true; }
209
  }
210
  if (!has_cell) {
211
- llama_kv_cell & empty_cell = cache.cells[next_empty_cell];
212
  GGML_ASSERT(empty_cell.is_empty());
213
  // copy old tail into the empty cell
214
  if (seq_meta.tail >= 0) {
215
- llama_kv_cell & orig_cell = cache.cells[seq_meta.tail];
216
  empty_cell.pos = orig_cell.pos;
217
  empty_cell.src = orig_cell.src;
218
  orig_cell.seq_id.erase(seq_id);
@@ -222,9 +609,9 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
222
  // find next empty cell
223
  if (s + 1 < n_seqs) {
224
  next_empty_cell += 1;
225
- for (uint32_t i = 0; i < cache.size; ++i) {
226
- if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
227
- llama_kv_cell & cell = cache.cells[next_empty_cell];
228
  if (cell.is_empty()) { break; }
229
  next_empty_cell += 1;
230
  }
@@ -237,10 +624,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
237
  // gather and re-order
238
  for (uint32_t s = 0; s < n_seqs; ++s) {
239
  int32_t dst_id = s + min;
240
- int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
241
  if (dst_id != src_id) {
242
- llama_kv_cell & dst_cell = cache.cells[dst_id];
243
- llama_kv_cell & src_cell = cache.cells[src_id];
244
 
245
  std::swap(dst_cell.pos, src_cell.pos);
246
  std::swap(dst_cell.src, src_cell.src);
@@ -248,10 +635,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
248
 
249
  // swap tails (assuming they NEVER overlap)
250
  for (const llama_seq_id seq_id : src_cell.seq_id) {
251
- cache.cells[seq_id].tail = src_id;
252
  }
253
  for (const llama_seq_id seq_id : dst_cell.seq_id) {
254
- cache.cells[seq_id].tail = dst_id;
255
  }
256
  }
257
  }
@@ -260,7 +647,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
260
  for (uint32_t s = 0; s < n_seqs; ++s) {
261
  const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
262
  int32_t cell_id = s + min;
263
- llama_kv_cell & cell = cache.cells[cell_id];
264
 
265
  if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
266
  // What should happen when the pos backtracks or skips a value?
@@ -273,41 +660,42 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
273
  for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
274
  const llama_seq_id seq_id = ubatch.seq_id[s][j];
275
  cell.seq_id.insert(seq_id);
276
- cache.cells[seq_id].tail = cell_id;
277
  }
278
  }
279
 
280
  // allow getting the range of used cells, from head to head + n
281
- cache.head = min;
282
- cache.n = max - min + 1;
283
- cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
284
  [](const llama_kv_cell& cell){ return !cell.is_empty(); });
285
 
286
  // sanity check
287
- return llama_kv_cache_slot_info(cache.n >= n_seqs);
288
  }
 
289
  // otherwise, one cell per token.
290
 
291
- if (n_tokens > cache.size) {
292
- LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
293
- return llama_kv_cache_slot_info_failed;
294
  }
295
 
296
  uint32_t n_tested = 0;
297
 
298
  while (true) {
299
- if (cache.head + n_tokens > cache.size) {
300
- n_tested += cache.size - cache.head;
301
- cache.head = 0;
302
  continue;
303
  }
304
 
305
  bool found = true;
306
  for (uint32_t i = 0; i < n_tokens; i++) {
307
- if (cache.cells[cache.head + i].pos >= 0) {
308
  found = false;
309
- cache.head += i + 1;
310
- n_tested += i + 1;
311
  break;
312
  }
313
  }
@@ -316,31 +704,38 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
316
  break;
317
  }
318
 
319
- if (n_tested >= cache.size) {
320
  //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
321
- return llama_kv_cache_slot_info_failed;
322
  }
323
  }
324
 
325
  for (uint32_t s = 0; s < n_seqs; s++) {
326
  for (uint32_t i = 0; i < n_seq_tokens; ++i) {
327
  uint32_t k = s*n_seq_tokens + i;
328
- cache.cells[cache.head + k].pos = ubatch.pos[k];
329
 
330
  for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
331
- cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
332
  }
333
  }
334
  }
335
 
336
- cache.used += n_tokens;
 
 
 
 
 
337
 
338
- return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens);
 
 
339
  }
340
 
341
- uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
342
- for (uint32_t i = cache.size; i > 0; --i) {
343
- const llama_kv_cell & cell = cache.cells[i - 1];
344
 
345
  if (cell.pos >= 0 && !cell.is_empty()) {
346
  return i;
@@ -350,289 +745,549 @@ uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
350
  return 0;
351
  }
352
 
353
- void llama_kv_cache_clear(struct llama_kv_cache & cache) {
354
- for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
355
- cache.cells[i].pos = -1;
356
- cache.cells[i].seq_id.clear();
357
- cache.cells[i].src = -1;
358
- cache.cells[i].tail = -1;
359
  }
360
- cache.head = 0;
361
- cache.used = 0;
362
 
363
- for (auto & buf : cache.bufs) {
364
- ggml_backend_buffer_clear(buf.get(), 0);
 
 
 
 
 
 
365
  }
 
 
366
  }
367
 
368
- bool llama_kv_cache_seq_rm(
369
- struct llama_kv_cache & cache,
370
- llama_seq_id seq_id,
371
- llama_pos p0,
372
- llama_pos p1) {
373
- uint32_t new_head = cache.size;
374
 
375
- if (p0 < 0) p0 = 0;
376
- if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
377
 
378
- // models like Mamba or RWKV can't have a state partially erased
379
- if (cache.recurrent) {
380
- if (seq_id >= (int64_t) cache.size) {
381
- // could be fatal
382
- return false;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
  }
384
- if (0 <= seq_id) {
385
- int32_t & tail_id = cache.cells[seq_id].tail;
386
- if (tail_id >= 0) {
387
- const llama_kv_cell & cell = cache.cells[tail_id];
388
- // partial intersection is invalid
389
- if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
390
- return false;
391
- }
392
- // invalidate tails which will be cleared
393
- if (p0 <= cell.pos && cell.pos < p1) {
394
- tail_id = -1;
395
- }
396
- }
397
- } else {
398
- // seq_id is negative, then the range should include everything or nothing
399
- if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
400
- return false;
401
- }
402
  }
403
- }
404
 
405
- for (uint32_t i = 0; i < cache.size; ++i) {
406
- if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
407
- if (seq_id < 0) {
408
- cache.cells[i].seq_id.clear();
409
- } else if (cache.cells[i].has_seq_id(seq_id)) {
410
- cache.cells[i].seq_id.erase(seq_id);
411
- } else {
 
412
  continue;
413
  }
414
- if (cache.cells[i].is_empty()) {
415
- // keep count of the number of used cells
416
- if (cache.cells[i].pos >= 0) cache.used--;
417
 
418
- cache.cells[i].pos = -1;
419
- cache.cells[i].src = -1;
420
- if (new_head == cache.size) new_head = i;
 
 
421
  }
422
  }
423
- }
424
 
425
- // If we freed up a slot, set head to it so searching can start there.
426
- if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
427
 
428
- return true;
429
- }
430
 
431
- void llama_kv_cache_seq_cp(
432
- struct llama_kv_cache & cache,
433
- llama_seq_id seq_id_src,
434
- llama_seq_id seq_id_dst,
435
- llama_pos p0,
436
- llama_pos p1) {
437
- if (p0 < 0) p0 = 0;
438
- if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
439
 
440
- if (cache.recurrent) {
441
- if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
442
- llama_kv_cell & tail_src = cache.cells[seq_id_src];
443
- llama_kv_cell & tail_dst = cache.cells[seq_id_dst];
444
- if (tail_dst.tail >= 0) {
445
- // clear destination seq_id if it wasn't empty
446
- llama_kv_cell & cell_dst = cache.cells[tail_dst.tail];
447
 
448
- cell_dst.seq_id.erase(seq_id_dst);
449
- tail_dst.tail = -1;
450
- if (cell_dst.seq_id.empty()) {
451
- cell_dst.pos = -1;
452
- cell_dst.delta = -1;
453
- cell_dst.src = -1;
454
- cache.used -= 1;
 
 
 
 
455
  }
 
 
 
456
  }
457
- if (tail_src.tail >= 0) {
458
- llama_kv_cell & cell_src = cache.cells[tail_src.tail];
459
 
460
- cell_src.seq_id.insert(seq_id_dst);
461
- tail_dst.tail = tail_src.tail;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
462
  }
463
  }
464
 
465
- return;
 
 
 
 
 
 
 
 
 
 
466
  }
467
- // otherwise, this is the KV cache of a Transformer-like model
468
 
469
- cache.head = 0;
470
 
471
- for (uint32_t i = 0; i < cache.size; ++i) {
472
- if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
473
- cache.cells[i].seq_id.insert(seq_id_dst);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
  }
475
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476
  }
477
 
478
- void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
479
- uint32_t new_head = cache.size;
 
480
 
481
- for (uint32_t i = 0; i < cache.size; ++i) {
482
- if (cache.recurrent && (llama_seq_id) i != seq_id) {
483
- cache.cells[i].tail = -1;
484
- }
485
- if (!cache.cells[i].has_seq_id(seq_id)) {
486
- if (cache.cells[i].pos >= 0) cache.used--;
487
- cache.cells[i].pos = -1;
488
- cache.cells[i].src = -1;
489
- cache.cells[i].seq_id.clear();
490
- if (new_head == cache.size) new_head = i;
491
  } else {
492
- cache.cells[i].seq_id.clear();
493
- cache.cells[i].seq_id.insert(seq_id);
494
  }
 
495
  }
496
-
497
- // If we freed up a slot, set head to it so searching can start there.
498
- if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
499
  }
500
 
501
- void llama_kv_cache_seq_add(
502
- struct llama_kv_cache & cache,
503
- llama_seq_id seq_id,
504
- llama_pos p0,
505
- llama_pos p1,
506
- llama_pos delta) {
507
- uint32_t new_head = cache.size;
508
 
509
- if (p0 < 0) p0 = 0;
510
- if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
511
- // If there is no range then return early to avoid looping over the cache.
512
- if (p0 == p1) return;
513
 
514
- if (cache.recurrent) {
515
- // for Mamba-like or RWKV models, only the pos needs to be shifted
516
- if (0 <= seq_id && seq_id < (int64_t) cache.size) {
517
- const int32_t tail_id = cache.cells[seq_id].tail;
518
- if (tail_id >= 0) {
519
- llama_kv_cell & cell = cache.cells[tail_id];
520
- if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
521
- cell.pos += delta;
522
  }
523
  }
524
  }
525
- return;
526
  }
 
527
 
528
- for (uint32_t i = 0; i < cache.size; ++i) {
529
- if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
530
- cache.has_shift = true;
531
- cache.cells[i].pos += delta;
532
- cache.cells[i].delta += delta;
533
 
534
- if (cache.cells[i].pos < 0) {
535
- if (!cache.cells[i].is_empty()) {
536
- cache.used--;
537
- }
538
- cache.cells[i].pos = -1;
539
- cache.cells[i].seq_id.clear();
540
- if (new_head == cache.size) {
541
- new_head = i;
542
- }
543
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
544
  }
545
  }
546
 
547
- // If we freed up a slot, set head to it so searching can start there.
548
- // Otherwise we just start the next search from the beginning.
549
- cache.head = new_head != cache.size ? new_head : 0;
550
- }
551
 
552
- void llama_kv_cache_seq_div(
553
- struct llama_kv_cache & cache,
554
- llama_seq_id seq_id,
555
- llama_pos p0,
556
- llama_pos p1,
557
- int d) {
558
- if (p0 < 0) p0 = 0;
559
- if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
560
- // If there is no range then return early to avoid looping over the cache.
561
- if (p0 == p1) return;
562
 
563
- if (cache.recurrent) {
564
- // for Mamba-like or RWKV models, only the pos needs to be changed
565
- if (0 <= seq_id && seq_id < (int64_t) cache.size) {
566
- const int32_t tail_id = cache.cells[seq_id].tail;
567
- if (tail_id >= 0) {
568
- llama_kv_cell & cell = cache.cells[tail_id];
569
- if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
570
- cell.pos /= d;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
571
  }
572
  }
573
  }
574
- return;
575
  }
 
576
 
577
- for (uint32_t i = 0; i < cache.size; ++i) {
578
- if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
579
- cache.has_shift = true;
580
 
581
- {
582
- llama_pos p_old = cache.cells[i].pos;
583
- cache.cells[i].pos /= d;
584
- cache.cells[i].delta += cache.cells[i].pos - p_old;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
586
  }
587
- }
588
- }
589
 
590
- llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) {
591
- llama_pos result = 0;
 
 
 
 
 
 
 
 
 
 
592
 
593
- for (uint32_t i = 0; i < cache.size; ++i) {
594
- if (cache.cells[i].has_seq_id(seq_id)) {
595
- result = std::max(result, cache.cells[i].pos);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
596
  }
 
 
 
597
  }
598
 
599
- return result;
 
 
 
 
 
 
 
 
600
  }
601
 
602
- void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
603
- if (!cache.recurrent) {
604
- cache.do_defrag = true;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
605
  }
606
- }
607
 
608
- int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv) {
609
- int result = 0;
 
610
 
611
- for (uint32_t i = 0; i < kv.size; i++) {
612
- result += kv.cells[i].seq_id.size();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
613
  }
614
 
615
- return result;
616
- }
 
617
 
618
- int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv) {
619
- return kv.used;
620
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
621
 
622
- bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv) {
623
- return kv.can_shift;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
624
  }
625
 
626
  //
627
  // kv cache view
628
  //
629
 
630
- struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max) {
631
- struct llama_kv_cache_view result = {
632
  /*.n_cells = */ 0,
633
  /*.n_seq_max = */ n_seq_max,
634
  /*.token_count = */ 0,
635
- /*.used_cells = */ llama_get_kv_cache_used_cells(kv),
636
  /*.max_contiguous = */ 0,
637
  /*.max_contiguous_idx = */ -1,
638
  /*.cells = */ nullptr,
@@ -642,7 +1297,7 @@ struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache
642
  return result;
643
  }
644
 
645
- void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
646
  if (view->cells != nullptr) {
647
  free(view->cells);
648
  view->cells = nullptr;
@@ -653,18 +1308,25 @@ void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
653
  }
654
  }
655
 
656
- void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv) {
657
- if (uint32_t(view->n_cells) < kv.size || view->cells == nullptr) {
658
- view->n_cells = int32_t(kv.size);
659
- void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
 
 
 
 
 
 
 
660
  GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
661
- view->cells = (struct llama_kv_cache_view_cell *)p;
662
  p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
663
  GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
664
  view->cells_sequences = (llama_seq_id *)p;
665
  }
666
 
667
- const std::vector<llama_kv_cell> & kv_cells = kv.cells;
668
  llama_kv_cache_view_cell * c_curr = view->cells;
669
  llama_seq_id * cs_curr = view->cells_sequences;
670
  int32_t used_cells = 0;
@@ -673,7 +1335,7 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct
673
  uint32_t max_contig = 0;
674
  int32_t max_contig_idx = -1;
675
 
676
- for (int32_t i = 0; i < int32_t(kv.size); i++, c_curr++, cs_curr += view->n_seq_max) {
677
  const size_t curr_size = kv_cells[i].seq_id.size();
678
  token_count += curr_size;
679
  c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
@@ -711,8 +1373,8 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct
711
  view->max_contiguous_idx = max_contig_idx;
712
  view->token_count = token_count;
713
  view->used_cells = used_cells;
714
- if (uint32_t(used_cells) != kv.used) {
715
  LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
716
- __func__, kv.used, used_cells);
717
  }
718
  }
 
6
  #include "llama-model.h"
7
 
8
  #include <algorithm>
9
+ #include <cassert>
10
  #include <limits>
11
  #include <map>
12
+ #include <stdexcept>
13
 
14
+ llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
 
 
 
 
15
  }
16
 
17
+ bool llama_kv_cache_unified::init(
18
+ const llama_model & model,
19
+ const llama_cparams & cparams,
20
+ ggml_type type_k,
21
+ ggml_type type_v,
22
+ uint32_t kv_size,
23
+ bool offload) {
 
 
 
24
  const int32_t n_layer = hparams.n_layer;
25
 
26
+ has_shift = false;
27
 
28
+ recurrent = llama_model_is_recurrent(&model);
29
+ v_trans = !recurrent && !cparams.flash_attn;
30
+ can_shift = !recurrent;
31
 
32
  LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
33
+ __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift);
34
 
35
+ head = 0;
36
+ size = kv_size;
37
+ used = 0;
38
 
39
+ this->type_k = type_k;
40
+ this->type_v = type_v;
41
 
42
+ cells.clear();
43
+ cells.resize(kv_size);
44
 
45
  // create a context for each buffer type
46
  std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
47
  auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
48
  auto it = ctx_map.find(buft);
49
  if (it == ctx_map.end()) {
50
+ ggml_init_params params = {
51
  /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
52
  /*.mem_buffer =*/ NULL,
53
  /*.no_alloc =*/ true,
54
  };
55
+
56
  ggml_context * ctx = ggml_init(params);
57
  if (!ctx) {
58
  return nullptr;
59
  }
60
+
61
  ctx_map[buft] = ctx;
62
+ ctxs.emplace_back(ctx);
63
+
64
  return ctx;
65
  }
66
+
67
  return it->second;
68
  };
69
 
70
+ k_l.reserve(n_layer);
71
+ v_l.reserve(n_layer);
72
 
73
  for (int i = 0; i < n_layer; i++) {
74
  const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
75
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
76
 
77
+ const char * dev_name = "CPU";
78
 
79
  ggml_backend_buffer_type_t buft;
80
  if (offload) {
81
  auto * dev = model.dev_layer(i);
82
  buft = ggml_backend_dev_buffer_type(dev);
83
+
84
+ dev_name = ggml_backend_dev_name(dev);
85
  } else {
86
  buft = ggml_backend_cpu_buffer_type();
87
  }
 
88
 
89
+ LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__,
90
+ i, n_embd_k_gqa, n_embd_v_gqa, dev_name);
91
+
92
+ ggml_context * ctx = ctx_for_buft(buft);
93
  if (!ctx) {
94
  LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
95
  return false;
 
99
  ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
100
  ggml_format_name(k, "cache_k_l%d", i);
101
  ggml_format_name(v, "cache_v_l%d", i);
102
+ k_l.push_back(k);
103
+ v_l.push_back(v);
104
  }
105
 
106
  // allocate tensors and initialize the buffers to avoid NaNs in the padding
 
115
  }
116
  ggml_backend_buffer_clear(buf, 0);
117
  LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
118
+ bufs.emplace_back(buf);
119
+ }
120
+
121
+ return true;
122
+ }
123
+
124
+ int32_t llama_kv_cache_unified::get_n_tokens() const {
125
+ int32_t result = 0;
126
+
127
+ for (uint32_t i = 0; i < size; i++) {
128
+ result += cells[i].seq_id.size();
129
+ }
130
+
131
+ return result;
132
+ }
133
+
134
+ int32_t llama_kv_cache_unified::get_used_cells() const {
135
+ return used;
136
+ }
137
+
138
+ size_t llama_kv_cache_unified::total_size() const {
139
+ size_t size = 0;
140
+ for (const auto & buf : bufs) {
141
+ size += ggml_backend_buffer_get_size(buf.get());
142
+ }
143
+
144
+ return size;
145
+ }
146
+
147
+ llama_pos llama_kv_cache_unified::pos_max() const {
148
+ llama_pos pos_max = -1;
149
+ for (const auto & cell : cells) {
150
+ pos_max = std::max(pos_max, cell.pos);
151
+ }
152
+
153
+ return pos_max;
154
+ }
155
+
156
+ void llama_kv_cache_unified::clear() {
157
+ for (int32_t i = 0; i < (int32_t) size; ++i) {
158
+ cells[i].pos = -1;
159
+ cells[i].seq_id.clear();
160
+ cells[i].src = -1;
161
+ cells[i].tail = -1;
162
+ }
163
+ head = 0;
164
+ used = 0;
165
+
166
+ for (auto & buf : bufs) {
167
+ ggml_backend_buffer_clear(buf.get(), 0);
168
+ }
169
+ }
170
+
171
+ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
172
+ uint32_t new_head = size;
173
+
174
+ if (p0 < 0) {
175
+ p0 = 0;
176
+ }
177
+
178
+ if (p1 < 0) {
179
+ p1 = std::numeric_limits<llama_pos>::max();
180
+ }
181
+
182
+ // models like Mamba or RWKV can't have a state partially erased
183
+ if (recurrent) {
184
+ if (seq_id >= (int64_t) size) {
185
+ // could be fatal
186
+ return false;
187
+ }
188
+ if (0 <= seq_id) {
189
+ int32_t & tail_id = cells[seq_id].tail;
190
+ if (tail_id >= 0) {
191
+ const llama_kv_cell & cell = cells[tail_id];
192
+ // partial intersection is invalid
193
+ if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
194
+ return false;
195
+ }
196
+ // invalidate tails which will be cleared
197
+ if (p0 <= cell.pos && cell.pos < p1) {
198
+ tail_id = -1;
199
+ }
200
+ }
201
+ } else {
202
+ // seq_id is negative, then the range should include everything or nothing
203
+ if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
204
+ return false;
205
+ }
206
+ }
207
+
208
+ return true;
209
+ }
210
+
211
+ for (uint32_t i = 0; i < size; ++i) {
212
+ if (cells[i].pos >= p0 && cells[i].pos < p1) {
213
+ if (seq_id < 0) {
214
+ cells[i].seq_id.clear();
215
+ } else if (cells[i].has_seq_id(seq_id)) {
216
+ cells[i].seq_id.erase(seq_id);
217
+ } else {
218
+ continue;
219
+ }
220
+ if (cells[i].is_empty()) {
221
+ // keep count of the number of used cells
222
+ if (cells[i].pos >= 0) {
223
+ used--;
224
+ }
225
+
226
+ cells[i].pos = -1;
227
+ cells[i].src = -1;
228
+
229
+ if (new_head == size) {
230
+ new_head = i;
231
+ }
232
+ }
233
+ }
234
+ }
235
+
236
+ // If we freed up a slot, set head to it so searching can start there.
237
+ if (new_head != size && new_head < head) {
238
+ head = new_head;
239
  }
240
 
241
  return true;
242
  }
243
 
244
+ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
245
+ if (seq_id_src == seq_id_dst) {
246
+ return;
247
+ }
248
+
249
+ if (p0 < 0) {
250
+ p0 = 0;
251
+ }
252
+
253
+ if (p1 < 0) {
254
+ p1 = std::numeric_limits<llama_pos>::max();
255
+ }
256
+
257
+ if (recurrent) {
258
+ if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
259
+ llama_kv_cell & tail_src = cells[seq_id_src];
260
+ llama_kv_cell & tail_dst = cells[seq_id_dst];
261
+ if (tail_dst.tail >= 0) {
262
+ // clear destination seq_id if it wasn't empty
263
+ llama_kv_cell & cell_dst = cells[tail_dst.tail];
264
+
265
+ cell_dst.seq_id.erase(seq_id_dst);
266
+ tail_dst.tail = -1;
267
+ if (cell_dst.seq_id.empty()) {
268
+ cell_dst.pos = -1;
269
+ cell_dst.delta = -1;
270
+ cell_dst.src = -1;
271
+ used -= 1;
272
+ }
273
+ }
274
+ if (tail_src.tail >= 0) {
275
+ llama_kv_cell & cell_src = cells[tail_src.tail];
276
+
277
+ cell_src.seq_id.insert(seq_id_dst);
278
+ tail_dst.tail = tail_src.tail;
279
+ }
280
+ }
281
+
282
+ return;
283
+ }
284
+
285
+ // otherwise, this is the KV of a Transformer-like model
286
+ head = 0;
287
+
288
+ for (uint32_t i = 0; i < size; ++i) {
289
+ if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) {
290
+ cells[i].seq_id.insert(seq_id_dst);
291
+ }
292
+ }
293
+ }
294
+
295
+ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
296
+ uint32_t new_head = size;
297
+
298
+ for (uint32_t i = 0; i < size; ++i) {
299
+ if (recurrent && (llama_seq_id) i != seq_id) {
300
+ cells[i].tail = -1;
301
+ }
302
+
303
+ if (!cells[i].has_seq_id(seq_id)) {
304
+ if (cells[i].pos >= 0) {
305
+ used--;
306
+ }
307
+
308
+ cells[i].pos = -1;
309
+ cells[i].src = -1;
310
+ cells[i].seq_id.clear();
311
+
312
+ if (new_head == size) {
313
+ new_head = i;
314
+ }
315
+ } else {
316
+ cells[i].seq_id.clear();
317
+ cells[i].seq_id.insert(seq_id);
318
+ }
319
+ }
320
+
321
+ // If we freed up a slot, set head to it so searching can start there.
322
+ if (new_head != size && new_head < head) {
323
+ head = new_head;
324
+ }
325
+ }
326
+
327
+ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
328
+ if (delta == 0) {
329
+ return;
330
+ }
331
+
332
+ uint32_t new_head = size;
333
+
334
+ if (p0 < 0) {
335
+ p0 = 0;
336
+ }
337
+
338
+ if (p1 < 0) {
339
+ p1 = std::numeric_limits<llama_pos>::max();
340
+ }
341
+
342
+ // If there is no range then return early to avoid looping over the cache.
343
+ if (p0 == p1) {
344
+ return;
345
+ }
346
+
347
+ if (recurrent) {
348
+ // for Mamba-like or RWKV models, only the pos needs to be shifted
349
+ if (0 <= seq_id && seq_id < (int64_t) size) {
350
+ const int32_t tail_id = cells[seq_id].tail;
351
+ if (tail_id >= 0) {
352
+ llama_kv_cell & cell = cells[tail_id];
353
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
354
+ cell.pos += delta;
355
+ }
356
+ }
357
+ }
358
+ return;
359
+ }
360
+
361
+ for (uint32_t i = 0; i < size; ++i) {
362
+ if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
363
+ has_shift = true;
364
+ cells[i].pos += delta;
365
+ cells[i].delta += delta;
366
+
367
+ if (cells[i].pos < 0) {
368
+ if (!cells[i].is_empty()) {
369
+ used--;
370
+ }
371
+ cells[i].pos = -1;
372
+ cells[i].seq_id.clear();
373
+ if (new_head == size) {
374
+ new_head = i;
375
+ }
376
+ }
377
+ }
378
+ }
379
+
380
+ // If we freed up a slot, set head to it so searching can start there.
381
+ // Otherwise we just start the next search from the beginning.
382
+ head = new_head != size ? new_head : 0;
383
+ }
384
+
385
+ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
386
+ if (d == 1) {
387
+ return;
388
+ }
389
+
390
+ if (p0 < 0) {
391
+ p0 = 0;
392
+ }
393
+
394
+ if (p1 < 0) {
395
+ p1 = std::numeric_limits<llama_pos>::max();
396
+ }
397
+
398
+ // If there is no range then return early to avoid looping over the cache.
399
+ if (p0 == p1) {
400
+ return;
401
+ }
402
+
403
+ if (recurrent) {
404
+ // for Mamba-like or RWKV models, only the pos needs to be changed
405
+ if (0 <= seq_id && seq_id < (int64_t) size) {
406
+ const int32_t tail_id = cells[seq_id].tail;
407
+ if (tail_id >= 0) {
408
+ llama_kv_cell & cell = cells[tail_id];
409
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
410
+ cell.pos /= d;
411
+ }
412
+ }
413
+ }
414
+
415
+ return;
416
+ }
417
+
418
+ for (uint32_t i = 0; i < size; ++i) {
419
+ if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
420
+ has_shift = true;
421
+
422
+ {
423
+ llama_pos p_old = cells[i].pos;
424
+ cells[i].pos /= d;
425
+ cells[i].delta += cells[i].pos - p_old;
426
+ }
427
+ }
428
+ }
429
+ }
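A worked example of the division above (illustrative):

// effect of seq_div(seq_id, 0, 8, 2):
//   pos:   0 1 2 3 4 5 6 7  ->  0 0 1 1 2 2 3 3
//   delta: accumulates (new pos - old pos), and has_shift is set so the
//          pending shift can later be applied to the cached keys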
430
+
431
+ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
432
+ llama_pos result = 0;
433
+
434
+ for (uint32_t i = 0; i < size; ++i) {
435
+ if (cells[i].has_seq_id(seq_id)) {
436
+ result = std::max(result, cells[i].pos);
437
+ }
438
+ }
439
+
440
+ return result;
441
+ }
442
+
443
+ void llama_kv_cache_unified::defrag() {
444
+ if (!recurrent) {
445
+ do_defrag = true;
446
+ }
447
+ }
448
+
449
+ void llama_kv_cache_unified::restore() {
450
+ if (pending.ranges.empty()) {
451
+ return;
452
+ }
453
+
454
+ // TODO: tmp - move to llama_kv_cache_recurrent
455
+ if (recurrent) {
456
+ seq_rm(-1, -1, -1);
457
+ return;
458
+ }
459
+
460
+ uint32_t new_head = size;
461
+
462
+ for (auto & range : pending.ranges) {
463
+ for (uint32_t i = range.c0; i < range.c1; ++i) {
464
+ cells[i].seq_id.clear();
465
+
466
+ // keep count of the number of used cells
467
+ if (cells[i].pos >= 0) {
468
+ used--;
469
+ }
470
+
471
+ cells[i].pos = -1;
472
+ cells[i].src = -1;
473
+ }
474
+
475
+ new_head = std::min(new_head, range.c0);
476
+ }
477
+
478
+ if (new_head != size && new_head < head) {
479
+ head = new_head;
480
+ }
481
+ }
482
+
483
+ void llama_kv_cache_unified::commit() {
484
+ // TODO: tmp - move to llama_kv_cache_recurrent
485
+ if (recurrent) {
486
+ return;
487
+ }
488
+
489
+ if (pending.ranges.empty()) {
490
+ LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
491
+ __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
492
+ return;
493
+ }
494
+
495
+ pending.ranges.clear();
496
+ }
497
+
498
+ bool llama_kv_cache_unified::get_can_shift() const {
499
+ return can_shift;
500
+ }
501
+
502
+ bool llama_kv_cache_unified::find_slot(
503
+ const llama_ubatch & ubatch) {
504
  const uint32_t n_tokens = ubatch.n_tokens;
505
  const uint32_t n_seqs = ubatch.n_seqs;
506
  const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
507
 
508
+ // if we have enough unused cells before the current head ->
509
+ // better to start searching from the beginning of the cache, hoping to fill it
510
+ if (head > used + 2*ubatch.n_tokens) {
511
+ head = 0;
512
+ }
513
+
514
+ if (recurrent) {
515
  // For recurrent state architectures (like Mamba or RWKV),
516
  // each cache cell can store the state for a whole sequence.
517
  // A slot should always be contiguous.
 
519
  // can only process batches with an equal number of new tokens in each sequence
520
  GGML_ASSERT(ubatch.equal_seqs);
521
 
522
+ int32_t min = size - 1;
523
  int32_t max = 0;
524
 
525
  // everything should fit if all seq_ids are smaller than the max
 
528
  for (uint32_t j = 0; j < n_seq_id; ++j) {
529
  const llama_seq_id seq_id = ubatch.seq_id[s][j];
530
 
531
+ if (seq_id < 0 || (uint32_t) seq_id >= size) {
532
  // too big seq_id
533
  // TODO: would it be possible to resize the cache instead?
534
+ LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
535
+ return false;
536
  }
537
  if (j > 0) {
538
+ llama_kv_cell & seq = cells[seq_id];
539
  if (seq.tail >= 0) {
540
+ llama_kv_cell & cell = cells[seq.tail];
541
  // clear cells from seq_ids that become shared
542
  // (should not normally happen, but let's handle it anyway)
543
  cell.seq_id.erase(seq_id);
 
545
  if (cell.seq_id.empty()) {
546
  cell.pos = -1;
547
  cell.src = -1;
548
+ used -= 1;
549
  }
550
  }
551
  }
 
555
  #ifndef NDEBUG
556
  {
557
  std::vector<int32_t> tails_verif;
558
+ tails_verif.assign(size, -1);
559
+ for (uint32_t i = 0; i < size; ++i) {
560
+ llama_kv_cell & cell = cells[i];
561
  for (llama_seq_id seq_id : cell.seq_id) {
562
  if (tails_verif[seq_id] != -1) {
563
  LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
 
565
  tails_verif[seq_id] = i;
566
  }
567
  }
568
+ for (uint32_t i = 0; i < size; ++i) {
569
+ if (tails_verif[i] != cells[i].tail) {
570
+ LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
571
  }
572
  }
573
  }
574
  #endif
575
 
576
  // find next empty cell
577
+ uint32_t next_empty_cell = head;
578
 
579
+ for (uint32_t i = 0; i < size; ++i) {
580
+ if (next_empty_cell >= size) { next_empty_cell -= size; }
581
+ llama_kv_cell & cell = cells[next_empty_cell];
582
  if (cell.is_empty()) { break; }
583
  next_empty_cell += 1;
584
  }
 
586
  // find usable cell range
587
  for (uint32_t s = 0; s < n_seqs; ++s) {
588
  const llama_seq_id seq_id = ubatch.seq_id[s][0];
589
+ llama_kv_cell & seq_meta = cells[seq_id];
590
  bool has_cell = false;
591
  if (seq_meta.tail >= 0) {
592
+ llama_kv_cell & cell = cells[seq_meta.tail];
593
  GGML_ASSERT(cell.has_seq_id(seq_id));
594
  // does this seq_id "own" the cell?
595
  if (cell.seq_id.size() == 1) { has_cell = true; }
596
  }
597
  if (!has_cell) {
598
+ llama_kv_cell & empty_cell = cells[next_empty_cell];
599
  GGML_ASSERT(empty_cell.is_empty());
600
  // copy old tail into the empty cell
601
  if (seq_meta.tail >= 0) {
602
+ llama_kv_cell & orig_cell = cells[seq_meta.tail];
603
  empty_cell.pos = orig_cell.pos;
604
  empty_cell.src = orig_cell.src;
605
  orig_cell.seq_id.erase(seq_id);
 
609
  // find next empty cell
610
  if (s + 1 < n_seqs) {
611
  next_empty_cell += 1;
612
+ for (uint32_t i = 0; i < size; ++i) {
613
+ if (next_empty_cell >= size) { next_empty_cell -= size; }
614
+ llama_kv_cell & cell = cells[next_empty_cell];
615
  if (cell.is_empty()) { break; }
616
  next_empty_cell += 1;
617
  }
 
624
  // gather and re-order
625
  for (uint32_t s = 0; s < n_seqs; ++s) {
626
  int32_t dst_id = s + min;
627
+ int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
628
  if (dst_id != src_id) {
629
+ llama_kv_cell & dst_cell = cells[dst_id];
630
+ llama_kv_cell & src_cell = cells[src_id];
631
 
632
  std::swap(dst_cell.pos, src_cell.pos);
633
  std::swap(dst_cell.src, src_cell.src);
 
635
 
636
  // swap tails (assuming they NEVER overlap)
637
  for (const llama_seq_id seq_id : src_cell.seq_id) {
638
+ cells[seq_id].tail = src_id;
639
  }
640
  for (const llama_seq_id seq_id : dst_cell.seq_id) {
641
+ cells[seq_id].tail = dst_id;
642
  }
643
  }
644
  }
 
647
  for (uint32_t s = 0; s < n_seqs; ++s) {
648
  const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
649
  int32_t cell_id = s + min;
650
+ llama_kv_cell & cell = cells[cell_id];
651
 
652
  if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
653
  // What should happen when the pos backtracks or skips a value?
 
660
  for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
661
  const llama_seq_id seq_id = ubatch.seq_id[s][j];
662
  cell.seq_id.insert(seq_id);
663
+ cells[seq_id].tail = cell_id;
664
  }
665
  }
666
 
667
  // allow getting the range of used cells, from head to head + n
668
+ head = min;
669
+ n = max - min + 1;
670
+ used = std::count_if(cells.begin(), cells.end(),
671
  [](const llama_kv_cell& cell){ return !cell.is_empty(); });
672
 
673
  // sanity check
674
+ return n >= n_seqs;
675
  }
676
+
677
  // otherwise, one cell per token.
678
 
679
+ if (n_tokens > size) {
680
+ LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
681
+ return false;
682
  }
683
 
684
  uint32_t n_tested = 0;
685
 
686
  while (true) {
687
+ if (head + n_tokens > size) {
688
+ n_tested += size - head;
689
+ head = 0;
690
  continue;
691
  }
692
 
693
  bool found = true;
694
  for (uint32_t i = 0; i < n_tokens; i++) {
695
+ if (cells[head + i].pos >= 0) {
696
  found = false;
697
+ head += i + 1;
698
+ n_tested += i + 1;
699
  break;
700
  }
701
  }
 
704
  break;
705
  }
706
 
707
+ if (n_tested >= size) {
708
  //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
709
+ return false;
710
  }
711
  }
712
 
713
  for (uint32_t s = 0; s < n_seqs; s++) {
714
  for (uint32_t i = 0; i < n_seq_tokens; ++i) {
715
  uint32_t k = s*n_seq_tokens + i;
716
+ cells[head + k].pos = ubatch.pos[k];
717
 
718
  for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
719
+ cells[head + k].seq_id.insert(ubatch.seq_id[s][j]);
720
  }
721
  }
722
  }
723
 
724
+ used += n_tokens;
725
+
726
+ pending.ranges.push_back({head, head + n_tokens});
727
+
728
+ return true;
729
+ }
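An illustrative call pattern for the slot API above (a sketch; the real caller is llama_context's decode path, which is not part of this excerpt, and kv_self / ubatch / ok are assumed names):

// reserve cells for the micro-batch
if (!kv_self.find_slot(ubatch)) {
    return 1; // cache is full - the caller may defrag or fail the batch
}

// ... build and compute the graph for this ubatch ...

if (!ok) {
    kv_self.restore(); // roll back the cell ranges recorded in pending.ranges
} else {
    kv_self.commit();  // keep them and clear the pending list
}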
730
 
731
+ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
732
+ // the FA kernels require padding to avoid extra runtime boundary checks
733
+ return cparams.flash_attn ? 256u : 32u;
734
  }
735
 
736
+ uint32_t llama_kv_cache_unified::cell_max() const {
737
+ for (uint32_t i = size; i > 0; --i) {
738
+ const llama_kv_cell & cell = cells[i - 1];
739
 
740
  if (cell.pos >= 0 && !cell.is_empty()) {
741
  return i;
 
745
  return 0;
746
  }
747
 
748
+ size_t llama_kv_cache_unified::size_k_bytes() const {
749
+ size_t size_k_bytes = 0;
750
+
751
+ for (const auto & k : k_l) {
752
+ size_k_bytes += ggml_nbytes(k);
 
753
  }
 
 
754
 
755
+ return size_k_bytes;
756
+ }
757
+
758
+ size_t llama_kv_cache_unified::size_v_bytes() const {
759
+ size_t size_v_bytes = 0;
760
+
761
+ for (const auto & v : v_l) {
762
+ size_v_bytes += ggml_nbytes(v);
763
  }
764
+
765
+ return size_v_bytes;
766
  }
767
 
768
+ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
769
+ const uint32_t n_layer = hparams.n_layer;
 
 
 
 
770
 
771
+ const uint32_t n_kv = cell_max();
772
+ const uint32_t n_used = used;
773
 
774
+ assert(n_used <= n_kv);
775
+
776
+ //const int64_t t_start = ggml_time_us();
777
+
778
+ // number of cells moved
779
+ uint32_t n_moves = 0;
780
+
781
+ // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
782
+ // - source view, destination view, copy operation
783
+ // - x2 for keys and values
784
+ //const uint32_t max_moves = max_nodes()/(6*n_layer);
785
+ // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
786
+ const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
787
+
788
+ // determine which KV cells to move where
789
+ //
790
+ // cell i moves to ids[i]
791
+ //
792
+ // if ids[i] == i || ids[i] == n_kv, then cell i is not moved
793
+ //
794
+ auto & ids = defrag_info.ids;
795
+
796
+ ids.clear();
797
+ ids.resize(n_kv, n_kv);
798
+
799
+ for (uint32_t i0 = 0; i0 < n_used; ++i0) {
800
+ const auto & cell0 = cells[i0];
801
+
802
+ if (!cell0.is_empty()) {
803
+ ids[i0] = i0;
804
+
805
+ continue;
806
  }
807
+
808
+ // found a hole - fill it with data from the end of the cache
809
+
810
+ uint32_t nh = 1;
811
+
812
+ // determine the size of the hole
813
+ while (i0 + nh < n_used && cells[i0 + nh].is_empty()) {
814
+ nh++;
 
 
 
 
 
 
 
 
 
 
815
  }
 
816
 
817
+ uint32_t nf = 0;
818
+ uint32_t is = n_kv - 1;
819
+
820
+ // starting from the end, find nh non-empty cells
821
+ for (; is > i0; --is) {
822
+ const auto & cell1 = cells[is];
823
+
824
+ if (cell1.is_empty() || ids[is] != n_kv) {
825
  continue;
826
  }
 
 
 
827
 
828
+ // non-empty cell which is not yet moved
829
+ nf++;
830
+
831
+ if (nf == nh) {
832
+ break;
833
  }
834
  }
 
835
 
836
+ // this can only happen if `n_used` is not accurate, which would be a bug
837
+ GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
838
 
839
+ nf = 0;
 
840
 
841
+ uint32_t i1 = is;
 
 
 
 
 
 
 
842
 
843
+ // are we moving a continuous block of memory?
844
+ bool cont = false;
 
 
 
 
 
845
 
846
+ // should we stop searching for the next move?
847
+ bool stop = false;
848
+
849
+ // go back and move the nf cells to the hole
850
+ for (; i1 < n_kv; ++i1) {
851
+ auto & cell1 = cells[i1];
852
+
853
+ if (cell1.is_empty() || ids[i1] != n_kv) {
854
+ if (n_moves == max_moves) {
855
+ stop = true;
856
+ break;
857
  }
858
+
859
+ cont = false;
860
+ continue;
861
  }
 
 
862
 
863
+ // this cell goes to (i0 + nf)
864
+ ids[i1] = i0 + nf;
865
+
866
+ // move the cell meta data
867
+ cells[i0 + nf] = cell1;
868
+
869
+ // clear the old cell and move the head there
870
+ cell1 = llama_kv_cell();
871
+ head = n_used;
872
+
873
+ if (!cont) {
874
+ n_moves++;
875
+ cont = true;
876
+ }
877
+
878
+ nf++;
879
+
880
+ if (nf == nh) {
881
+ break;
882
  }
883
  }
884
 
885
+ if (stop || n_moves == max_moves) {
886
+ break;
887
+ }
888
+
889
+ //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
890
+
891
+ i0 += nh - 1;
892
+ }
893
+
894
+ if (n_moves == 0) {
895
+ return false;
896
  }
 
897
 
898
+ LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves);
899
 
900
+ LLAMA_LOG_DEBUG("expected gf nodes: %u\n", 6*n_moves*n_layer);
901
+
902
+ return true;
903
+ }
904
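A minimal standalone sketch of how a move plan in this format can be applied; the `apply_moves` helper and the simplified `cell_data` payload are hypothetical, and only the convention that `ids[i] == i || ids[i] == n_kv` means "not moved" is taken from defrag_prepare above.

#include <cstdint>
#include <vector>

struct cell_data { int32_t pos = -1; }; // simplified stand-in for per-cell payload

// cell i moves to ids[i]; ids[i] == i or ids[i] == ids.size() means "leave in place"
static void apply_moves(std::vector<cell_data> & cells, const std::vector<uint32_t> & ids) {
    const uint32_t n_kv = (uint32_t) ids.size();
    for (uint32_t i = 0; i < n_kv; ++i) {
        const uint32_t id = ids[i];
        if (id == i || id == n_kv) {
            continue; // not moved
        }
        cells[id] = cells[i];    // fill the hole
        cells[i]  = cell_data(); // clear the source
    }
}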
+
905
+ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
906
+ std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
907
+ uint32_t cell_count = 0;
908
+
909
+ // Count the number of cells with the specified seq_id
910
+ // Find all the ranges of cells with this seq id (or all, when -1)
911
+ uint32_t cell_range_begin = size;
912
+ for (uint32_t i = 0; i < size; ++i) {
913
+ const auto & cell = cells[i];
914
+ if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
915
+ ++cell_count;
916
+ if (cell_range_begin == size) {
917
+ cell_range_begin = i;
918
+ }
919
+ } else {
920
+ if (cell_range_begin != size) {
921
+ cell_ranges.emplace_back(cell_range_begin, i);
922
+ cell_range_begin = size;
923
+ }
924
  }
925
  }
926
+ if (cell_range_begin != size) {
927
+ cell_ranges.emplace_back(cell_range_begin, size);
928
+ }
929
+
930
+ // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
931
+ uint32_t cell_count_check = 0;
932
+ for (const auto & range : cell_ranges) {
933
+ cell_count_check += range.second - range.first;
934
+ }
935
+ GGML_ASSERT(cell_count == cell_count_check);
936
+
937
+ io.write(&cell_count, sizeof(cell_count));
938
+
939
+ state_write_meta(io, cell_ranges, seq_id);
940
+ state_write_data(io, cell_ranges);
941
  }
942
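The range collection in state_write above generalizes to a small standalone helper; the name `collect_ranges` and the predicate interface are made up, but the sentinel and begin-inclusive/end-exclusive logic mirror the loop above.

#include <cstdint>
#include <utility>
#include <vector>

// Group indices i in [0, size) for which pred(i) is true into [first, second) ranges.
template <typename Pred>
static std::vector<std::pair<uint32_t, uint32_t>> collect_ranges(uint32_t size, Pred pred) {
    std::vector<std::pair<uint32_t, uint32_t>> ranges;
    uint32_t begin = size; // `size` doubles as the "no open range" sentinel
    for (uint32_t i = 0; i < size; ++i) {
        if (pred(i)) {
            if (begin == size) {
                begin = i; // open a new range
            }
        } else if (begin != size) {
            ranges.emplace_back(begin, i); // close the current range
            begin = size;
        }
    }
    if (begin != size) {
        ranges.emplace_back(begin, size); // final range runs to the end
    }
    return ranges;
}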
 
943
+ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
944
+ uint32_t cell_count;
945
+ io.read_to(&cell_count, sizeof(cell_count));
946
 
947
+ bool res = true;
948
+ res = res && state_read_meta(io, cell_count, seq_id);
949
+ res = res && state_read_data(io, cell_count);
950
+
951
+ if (!res) {
952
+ if (seq_id == -1) {
953
+ clear();
 
 
 
954
  } else {
955
+ seq_rm(seq_id, -1, -1);
 
956
  }
957
+ throw std::runtime_error("failed to restore kv cache");
958
  }
 
 
 
959
  }
960
 
961
+ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
962
+ for (const auto & range : cell_ranges) {
963
+ for (uint32_t i = range.first; i < range.second; ++i) {
964
+ const auto & cell = cells[i];
965
+ const llama_pos pos = cell.pos;
966
+ const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
 
967
 
968
+ io.write(&pos, sizeof(pos));
969
+ io.write(&n_seq_id, sizeof(n_seq_id));
 
 
970
 
971
+ if (n_seq_id) {
972
+ for (auto seq_id : cell.seq_id) {
973
+ io.write(&seq_id, sizeof(seq_id));
 
 
 
 
 
974
  }
975
  }
976
  }
 
977
  }
978
+ }
979
 
980
+ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
981
+ const uint32_t v_trans = this->v_trans ? 1 : 0;
982
+ const uint32_t n_layer = hparams.n_layer;
 
 
983
 
984
+ io.write(&v_trans, sizeof(v_trans));
985
+ io.write(&n_layer, sizeof(n_layer));
986
+
987
+ std::vector<uint8_t> tmp_buf;
988
+
989
+ // Iterate and write all the keys first, each row is a cell
990
+ // Get whole range at a time
991
+ for (uint32_t il = 0; il < n_layer; ++il) {
992
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
993
+
994
+ // Write key type
995
+ const int32_t k_type_i = (int32_t)k_l[il]->type;
996
+ io.write(&k_type_i, sizeof(k_type_i));
997
+
998
+ // Write row size of key
999
+ const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
1000
+ io.write(&k_size_row, sizeof(k_size_row));
1001
+
1002
+ // Read each range of cells of k_size length each into tmp_buf and write out
1003
+ for (const auto & range : cell_ranges) {
1004
+ const size_t range_size = range.second - range.first;
1005
+ const size_t buf_size = range_size * k_size_row;
1006
+ io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
1007
  }
1008
  }
1009
 
1010
+ if (!v_trans) {
1011
+ for (uint32_t il = 0; il < n_layer; ++il) {
1012
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
1013
 
1014
+ // Write value type
1015
+ const int32_t v_type_i = (int32_t)v_l[il]->type;
1016
+ io.write(&v_type_i, sizeof(v_type_i));
 
 
 
 
 
 
 
1017
 
1018
+ // Write row size of value
1019
+ const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
1020
+ io.write(&v_size_row, sizeof(v_size_row));
1021
+
1022
+ // Read each range of cells of v_size length each into tmp_buf and write out
1023
+ for (const auto & range : cell_ranges) {
1024
+ const size_t range_size = range.second - range.first;
1025
+ const size_t buf_size = range_size * v_size_row;
1026
+ io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
1027
+ }
1028
+ }
1029
+ } else {
1030
+ // When v is transposed, we also need the element size and get the element ranges from each row
1031
+ const uint32_t kv_size = size;
1032
+ for (uint32_t il = 0; il < n_layer; ++il) {
1033
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1034
+
1035
+ // Write value type
1036
+ const int32_t v_type_i = (int32_t)v_l[il]->type;
1037
+ io.write(&v_type_i, sizeof(v_type_i));
1038
+
1039
+ // Write element size
1040
+ const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
1041
+ io.write(&v_size_el, sizeof(v_size_el));
1042
+
1043
+ // Write GQA embedding size
1044
+ io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
1045
+
1046
+ // For each row, we get the element values of each cell
1047
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1048
+ // Read each range of cells of v_size_el length each into tmp_buf and write out
1049
+ for (const auto & range : cell_ranges) {
1050
+ const size_t range_size = range.second - range.first;
1051
+ const size_t src_offset = (range.first + j * kv_size) * v_size_el;
1052
+ const size_t buf_size = range_size * v_size_el;
1053
+ io.write_tensor(v_l[il], src_offset, buf_size);
1054
  }
1055
  }
1056
  }
 
1057
  }
1058
+ }
1059
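For the transposed-V branch above, the offset arithmetic can be summarized in a tiny helper (the name is invented; the formula is the one used in state_write_data and mirrored by state_read_data): element j of cell c lives at flat element index c + j*kv_size, so for a fixed j each cell range is one contiguous run.

#include <cstddef>
#include <cstdint>

static size_t v_trans_offset_bytes(uint32_t cell_first, uint32_t j, uint32_t kv_size, size_t v_size_el) {
    return ((size_t) cell_first + (size_t) j * kv_size) * v_size_el;
}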
 
1060
+ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
1061
+ if (dest_seq_id != -1) {
1062
+ // single sequence
1063
 
1064
+ seq_rm(dest_seq_id, -1, -1);
1065
+
1066
+ llama_sbatch sbatch;
1067
+ llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
1068
+
1069
+ batch.n_tokens = cell_count;
1070
+ batch.n_seq_tokens = cell_count;
1071
+ batch.n_seqs = 1;
1072
+
1073
+ for (uint32_t i = 0; i < cell_count; ++i) {
1074
+ llama_pos pos;
1075
+ uint32_t n_seq_id;
1076
+
1077
+ io.read_to(&pos, sizeof(pos));
1078
+ io.read_to(&n_seq_id, sizeof(n_seq_id));
1079
+
1080
+ if (n_seq_id != 0) {
1081
+ LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
1082
+ return false;
1083
  }
1084
+
1085
+ batch.pos[i] = pos;
1086
+ }
1087
+ batch.n_seq_id[0] = 1;
1088
+ batch.seq_id[0] = &dest_seq_id;
1089
+ if (!find_slot(batch)) {
1090
+ LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1091
+ return false;
1092
+ }
1093
+ commit();
1094
+
1095
+ // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
1096
+ // Assume that this is one contiguous block of cells
1097
+ GGML_ASSERT(head + cell_count <= size);
1098
+ GGML_ASSERT(cells[head].pos == batch.pos[0]);
1099
+ GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
1100
+ GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
1101
+ GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
1102
+ } else {
1103
+ // whole KV cache restore
1104
+
1105
+ if (cell_count > size) {
1106
+ LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
1107
+ return false;
1108
  }
 
 
1109
 
1110
+ clear();
1111
+
1112
+ for (uint32_t i = 0; i < cell_count; ++i) {
1113
+ llama_kv_cell & cell = cells[i];
1114
+
1115
+ llama_pos pos;
1116
+ uint32_t n_seq_id;
1117
+
1118
+ io.read_to(&pos, sizeof(pos));
1119
+ io.read_to(&n_seq_id, sizeof(n_seq_id));
1120
+
1121
+ cell.pos = pos;
1122
 
1123
+ for (uint32_t j = 0; j < n_seq_id; ++j) {
1124
+ llama_seq_id seq_id;
1125
+ io.read_to(&seq_id, sizeof(seq_id));
1126
+
1127
+ // TODO: llama_kv_cache_unified should have a notion of max sequences
1128
+ //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
1129
+ if (seq_id < 0) {
1130
+ //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
1131
+ LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
1132
+ return false;
1133
+ }
1134
+
1135
+ cell.seq_id.insert(seq_id);
1136
+
1137
+ if (recurrent) {
1138
+ int32_t & tail = cells[seq_id].tail;
1139
+ if (tail != -1) {
1140
+ LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
1141
+ return false;
1142
+ }
1143
+ tail = i;
1144
+ }
1145
+ }
1146
  }
1147
+
1148
+ head = 0;
1149
+ used = cell_count;
1150
  }
1151
 
1152
+ if (recurrent) {
1153
+ for (uint32_t i = 0; i < cell_count; ++i) {
1154
+ uint32_t cell_id = head + i;
1155
+ // make sure the recurrent states will keep their restored state
1156
+ cells[cell_id].src = cell_id;
1157
+ }
1158
+ }
1159
+
1160
+ return true;
1161
  }
1162
 
1163
+ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
1164
+ uint32_t v_trans;
1165
+ uint32_t n_layer;
1166
+ io.read_to(&v_trans, sizeof(v_trans));
1167
+ io.read_to(&n_layer, sizeof(n_layer));
1168
+
1169
+ if (n_layer != hparams.n_layer) {
1170
+ LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
1171
+ return false;
1172
+ }
1173
+ if (cell_count > size) {
1174
+ LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
1175
+ return false;
1176
+ }
1177
+ if (this->v_trans != (bool) v_trans) {
1178
+ LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
1179
+ return false;
1180
  }
 
1181
 
1182
+ // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
1183
+ for (uint32_t il = 0; il < n_layer; ++il) {
1184
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1185
 
1186
+ // Read type of key
1187
+ int32_t k_type_i_ref;
1188
+ io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
1189
+ const int32_t k_type_i = (int32_t) k_l[il]->type;
1190
+ if (k_type_i != k_type_i_ref) {
1191
+ LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
1192
+ return false;
1193
+ }
1194
+
1195
+ // Read row size of key
1196
+ uint64_t k_size_row_ref;
1197
+ io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
1198
+ const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
1199
+ if (k_size_row != k_size_row_ref) {
1200
+ LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
1201
+ return false;
1202
+ }
1203
+
1204
+ if (cell_count) {
1205
+ // Read and set the keys for the whole cell range
1206
+ ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
1207
+ }
1208
  }
1209
 
1210
+ if (!v_trans) {
1211
+ for (uint32_t il = 0; il < n_layer; ++il) {
1212
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1213
 
1214
+ // Read type of value
1215
+ int32_t v_type_i_ref;
1216
+ io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1217
+ const int32_t v_type_i = (int32_t)v_l[il]->type;
1218
+ if (v_type_i != v_type_i_ref) {
1219
+ LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1220
+ return false;
1221
+ }
1222
+
1223
+ // Read row size of value
1224
+ uint64_t v_size_row_ref;
1225
+ io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
1226
+ const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
1227
+ if (v_size_row != v_size_row_ref) {
1228
+ LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
1229
+ return false;
1230
+ }
1231
 
1232
+ if (cell_count) {
1233
+ // Read and set the values for the whole cell range
1234
+ ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
1235
+ }
1236
+ }
1237
+ } else {
1238
+ // For each layer, read the values for each cell (transposed)
1239
+ for (uint32_t il = 0; il < n_layer; ++il) {
1240
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1241
+
1242
+ // Read type of value
1243
+ int32_t v_type_i_ref;
1244
+ io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1245
+ const int32_t v_type_i = (int32_t)v_l[il]->type;
1246
+ if (v_type_i != v_type_i_ref) {
1247
+ LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1248
+ return false;
1249
+ }
1250
+
1251
+ // Read element size of value
1252
+ uint32_t v_size_el_ref;
1253
+ io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
1254
+ const size_t v_size_el = ggml_type_size(v_l[il]->type);
1255
+ if (v_size_el != v_size_el_ref) {
1256
+ LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
1257
+ return false;
1258
+ }
1259
+
1260
+ // Read GQA embedding size
1261
+ uint32_t n_embd_v_gqa_ref;
1262
+ io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
1263
+ if (n_embd_v_gqa != n_embd_v_gqa_ref) {
1264
+ LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
1265
+ return false;
1266
+ }
1267
+
1268
+ if (cell_count) {
1269
+ // For each row in the transposed matrix, read the values for the whole cell range
1270
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1271
+ const size_t dst_offset = (head + j * size) * v_size_el;
1272
+ ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
1273
+ }
1274
+ }
1275
+ }
1276
+ }
1277
+
1278
+ return true;
1279
  }
1280
 
1281
  //
1282
  // kv cache view
1283
  //
1284
 
1285
+ llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max) {
1286
+ llama_kv_cache_view result = {
1287
  /*.n_cells = */ 0,
1288
  /*.n_seq_max = */ n_seq_max,
1289
  /*.token_count = */ 0,
1290
+ /*.used_cells = */ kv.get_used_cells(),
1291
  /*.max_contiguous = */ 0,
1292
  /*.max_contiguous_idx = */ -1,
1293
  /*.cells = */ nullptr,
 
1297
  return result;
1298
  }
1299
 
1300
+ void llama_kv_cache_view_free(llama_kv_cache_view * view) {
1301
  if (view->cells != nullptr) {
1302
  free(view->cells);
1303
  view->cells = nullptr;
 
1308
  }
1309
  }
1310
 
1311
+ void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv) {
1312
+ // TODO: rework this in the future, for now quick hack
1313
+ const llama_kv_cache_unified * kvu = dynamic_cast<const llama_kv_cache_unified *>(kv);
1314
+ if (kvu == nullptr) {
1315
+ LLAMA_LOG_ERROR("%s: the kv_cache_view currently works only with llama_kv_cache_unified\n", __func__);
1316
+ return;
1317
+ }
1318
+
1319
+ if (uint32_t(view->n_cells) < kvu->size || view->cells == nullptr) {
1320
+ view->n_cells = int32_t(kvu->size);
1321
+ void * p = realloc(view->cells, sizeof(llama_kv_cache_view_cell) * view->n_cells);
1322
  GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
1323
+ view->cells = (llama_kv_cache_view_cell *)p;
1324
  p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
1325
  GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
1326
  view->cells_sequences = (llama_seq_id *)p;
1327
  }
1328
 
1329
+ const std::vector<llama_kv_cell> & kv_cells = kvu->cells;
1330
  llama_kv_cache_view_cell * c_curr = view->cells;
1331
  llama_seq_id * cs_curr = view->cells_sequences;
1332
  int32_t used_cells = 0;
 
1335
  uint32_t max_contig = 0;
1336
  int32_t max_contig_idx = -1;
1337
 
1338
+ for (int32_t i = 0; i < int32_t(kvu->size); i++, c_curr++, cs_curr += view->n_seq_max) {
1339
  const size_t curr_size = kv_cells[i].seq_id.size();
1340
  token_count += curr_size;
1341
  c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
 
1373
  view->max_contiguous_idx = max_contig_idx;
1374
  view->token_count = token_count;
1375
  view->used_cells = used_cells;
1376
+ if (uint32_t(used_cells) != kvu->used) {
1377
  LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
1378
+ __func__, kvu->used, used_cells);
1379
  }
1380
  }
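A rough usage sketch for the view helpers defined above; `kv` is assumed to be an existing llama_kv_cache obtained elsewhere, and treating `pos < 0` as "empty" follows the update logic.

#include "llama-kv-cache.h"

#include <cstdio>

static void dump_kv_view(const llama_kv_cache & kv, int32_t n_seq_max) {
    llama_kv_cache_view view = llama_kv_cache_view_init(kv, n_seq_max);
    llama_kv_cache_view_update(&view, &kv);

    printf("cells: %d, used: %d, tokens: %d, max contiguous: %d (idx %d)\n",
            view.n_cells, view.used_cells, view.token_count,
            view.max_contiguous, view.max_contiguous_idx);

    for (int32_t i = 0; i < view.n_cells; ++i) {
        if (view.cells[i].pos >= 0) {
            printf("  cell %3d: pos = %d\n", i, view.cells[i].pos);
        }
    }

    llama_kv_cache_view_free(&view);
}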
examples/talk-llama/llama-kv-cache.h CHANGED
@@ -1,15 +1,51 @@
1
  #pragma once
2
 
3
  #include "llama.h"
 
 
4
 
5
  #include "ggml-cpp.h"
6
 
 
7
  #include <set>
8
  #include <vector>
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  struct llama_kv_cell {
11
  llama_pos pos = -1;
12
- llama_pos delta = 0;
13
  int32_t src = -1; // used by recurrent state models to copy states
14
  int32_t tail = -1;
15
 
@@ -29,190 +65,149 @@ struct llama_kv_cell {
29
  };
30
 
31
  // ring-buffer of cached KV data
32
- struct llama_kv_cache {
33
- bool has_shift = false;
34
- bool do_defrag = false;
35
- bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
36
- bool v_trans = true; // the value tensor is transposed
37
- bool can_shift = false;
38
-
39
- // Note: The value of head isn't only used to optimize searching
40
- // for a free KV slot. llama_decode_internal also uses it, so it
41
- // cannot be freely changed after a slot has been allocated.
42
- uint32_t head = 0;
43
- uint32_t size = 0;
44
- uint32_t used = 0; // used cells (i.e. at least one seq_id)
 
 
 
 
 
 
 
 
 
 
45
 
46
- // computed before each graph build
47
- uint32_t n = 0;
48
 
49
- ggml_type type_k = GGML_TYPE_F16;
50
- ggml_type type_v = GGML_TYPE_F16;
51
 
52
- std::vector<llama_kv_cell> cells;
 
53
 
54
- std::vector<struct ggml_tensor *> k_l; // per layer
55
- std::vector<struct ggml_tensor *> v_l;
56
 
57
- std::vector<ggml_context_ptr> ctxs;
58
- std::vector<ggml_backend_buffer_ptr> bufs;
59
 
60
- size_t total_size() const {
61
- size_t size = 0;
62
- for (const auto & buf : bufs) {
63
- size += ggml_backend_buffer_get_size(buf.get());
64
- }
65
 
66
- return size;
67
- }
68
 
69
- // TODO: better data structures to reduce the cost of this operation
70
- llama_pos max_pos() const {
71
- llama_pos max_pos = -1;
72
- for (const auto & cell : cells) {
73
- max_pos = std::max(max_pos, cell.pos);
74
- }
75
 
76
- return max_pos;
77
- }
78
- };
 
 
79
 
80
- // a structure holds information about the slot found in llama_kv_cache_find_slot
81
- struct llama_kv_cache_slot_info {
82
- std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
83
- bool found = false; // the slot was found
84
 
85
- explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
86
- llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
87
 
88
- operator bool() const { return found; }
89
- };
90
 
91
- // TODO: maybe not needed
92
- uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams);
93
 
94
- bool llama_kv_cache_init(
95
- struct llama_kv_cache & cache,
96
- const llama_model & model,
97
- const llama_cparams & cparams,
98
- ggml_type type_k,
99
- ggml_type type_v,
100
- uint32_t kv_size,
101
- bool offload);
102
 
103
- // find an empty slot of size "n_tokens" in the cache
104
- // updates the cache head
105
- // returns a structure holding information about the slot found
106
- // Note: On success, it's important that cache.head points
107
- // to the first cell of the slot.
108
- struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
109
- struct llama_kv_cache & cache,
110
- const struct llama_ubatch & batch);
111
 
112
- // find how many cells are currently in use
113
- uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache);
114
 
115
- void llama_kv_cache_clear(struct llama_kv_cache & cache);
 
 
 
116
 
117
- bool llama_kv_cache_seq_rm(
118
- struct llama_kv_cache & cache,
119
- llama_seq_id seq_id,
120
- llama_pos p0,
121
- llama_pos p1);
122
 
123
- void llama_kv_cache_seq_cp(
124
- struct llama_kv_cache & cache,
125
- llama_seq_id seq_id_src,
126
- llama_seq_id seq_id_dst,
127
- llama_pos p0,
128
- llama_pos p1);
129
 
130
- void llama_kv_cache_seq_keep(
131
- struct llama_kv_cache & cache,
132
- llama_seq_id seq_id);
133
 
134
- void llama_kv_cache_seq_add(
135
- struct llama_kv_cache & cache,
136
- llama_seq_id seq_id,
137
- llama_pos p0,
138
- llama_pos p1,
139
- llama_pos delta);
140
 
141
- void llama_kv_cache_seq_div(
142
- struct llama_kv_cache & cache,
143
- llama_seq_id seq_id,
144
- llama_pos p0,
145
- llama_pos p1,
146
- int d);
147
 
148
- llama_pos llama_kv_cache_seq_pos_max(
149
- struct llama_kv_cache & cache,
150
- llama_seq_id seq_id);
151
 
152
- void llama_kv_cache_defrag(struct llama_kv_cache & cache);
 
153
 
154
- int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv);
 
155
 
156
- int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv);
 
157
 
158
- bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv);
 
 
 
 
 
159
 
160
- //
161
- // kv cache view
162
- //
163
 
164
- struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max);
165
 
166
- void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv);
 
167
 
168
- //
169
- // kv cache restore
170
- //
171
 
172
- // saves the kv_cache state for future recovery.
173
- // used to rollback llama_kv_cache_find_slot changes.
174
- struct llama_kv_slot_restorer {
175
- struct llama_kv_cache_state {
176
- uint32_t head = 0;
177
- uint32_t n = 0;
178
- } old_state;
179
 
180
- // for non-recurrent models only
181
- // list of slots to restore
182
- std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
183
 
184
- bool do_restore = false;
 
 
185
 
186
- explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
187
- old_state.head = cache.head;
188
- old_state.n = cache.n;
189
- }
 
190
 
191
- // saves a slot information for future restoration
192
- void save(const struct llama_kv_cache_slot_info & slot) {
193
- if (slot) {
194
- do_restore = true;
195
- if (slot.boundaries.first != slot.boundaries.second) {
196
- slot_boundaries.push_back(slot.boundaries);
197
- }
198
- }
199
- }
200
 
201
- // must be explicitly called to restore the kv_cache state
202
- // and rollback changes from all llama_kv_cache_find_slot calls
203
- void restore(struct llama_kv_cache & cache) {
204
- if (do_restore) {
205
- cache.head = old_state.head;
206
- cache.n = old_state.n;
207
-
208
- if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
209
- llama_kv_cache_seq_rm(cache, -1, -1, -1);
210
- } else {
211
- for (auto & slot : slot_boundaries) {
212
- llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
213
- }
214
- }
215
- }
216
- }
217
- };
218
 
 
 
1
  #pragma once
2
 
3
  #include "llama.h"
4
+ #include "llama-io.h"
5
+ #include "llama-memory.h"
6
 
7
  #include "ggml-cpp.h"
8
 
9
+ #include <functional>
10
  #include <set>
11
  #include <vector>
12
 
13
+ struct llama_cparams;
14
+ struct llama_hparams;
15
+ struct llama_ubatch;
16
+
17
+ struct llama_kv_cache : public llama_memory_i {
18
+ using llama_memory_i::llama_memory_i;
19
+
20
+ virtual void restore() = 0; // call if batch processing fails - restores the cache state
21
+ virtual void commit() = 0; // call after successful batch processing - clears any pending state
22
+
23
+ virtual int32_t get_n_tokens() const = 0;
24
+ virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
25
+
26
+ virtual bool get_can_shift() const = 0;
27
+
28
+ bool get_can_edit() const override { return get_can_shift(); }
29
+ };
30
+
31
+ struct llama_kv_cache_guard {
32
+ llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
33
+
34
+ ~llama_kv_cache_guard() {
35
+ kv->restore();
36
+ }
37
+
38
+ void commit() {
39
+ kv->commit();
40
+ }
41
+
42
+ private:
43
+ llama_kv_cache * kv;
44
+ };
45
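A sketch of the commit/restore pattern the guard is meant to enable; `process_batch` is a placeholder for whatever batch-processing step mutates the cache.

#include "llama-kv-cache.h"

bool process_batch(); // placeholder: the step that may fail

bool decode_with_guard(llama_kv_cache * kv) {
    llama_kv_cache_guard kv_guard(kv);

    if (!process_batch()) {
        return false; // guard destructor calls kv->restore() and rolls back pending changes
    }

    kv_guard.commit(); // success: clear pending state so the final restore() has nothing to undo
    return true;
}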
+
46
  struct llama_kv_cell {
47
  llama_pos pos = -1;
48
+ llama_pos delta = 0;
49
  int32_t src = -1; // used by recurrent state models to copy states
50
  int32_t tail = -1;
51
 
 
65
  };
66
 
67
  // ring-buffer of cached KV data
68
+ // TODO: pimpl
69
+ // TODO: add notion of max sequences
70
+ class llama_kv_cache_unified : public llama_kv_cache {
71
+ public:
72
+ // can be used to query data from the model if needed
73
+ struct callbacks {
74
+ std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
75
+ };
76
+
77
+ llama_kv_cache_unified(
78
+ const llama_hparams & hparams,
79
+ callbacks cbs);
80
+
81
+ virtual ~llama_kv_cache_unified() = default;
82
+
83
+ // TODO: become constructor
84
+ bool init(
85
+ const llama_model & model, // TODO: do not reference the model
86
+ const llama_cparams & cparams,
87
+ ggml_type type_k,
88
+ ggml_type type_v,
89
+ uint32_t kv_size,
90
+ bool offload);
91
 
92
+ int32_t get_n_tokens() const override;
93
+ int32_t get_used_cells() const override;
94
 
95
+ size_t total_size() const;
 
96
 
97
+ // TODO: better data structures to reduce the cost of this operation
98
+ llama_pos pos_max() const;
99
 
100
+ void clear() override;
101
+ void defrag() override;
102
 
103
+ virtual void restore() override;
104
+ virtual void commit() override;
105
 
106
+ bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
107
+ void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
108
+ void seq_keep(llama_seq_id seq_id) override;
109
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
110
+ void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
111
 
112
+ llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
113
 
114
+ bool get_can_shift() const override;
 
 
 
 
 
115
 
116
+ // find an empty slot of size "n_tokens" in the cache
117
+ // updates the cache head
118
+ // Note: On success, it's important that cache.head points
119
+ // to the first cell of the slot.
120
+ bool find_slot(const llama_ubatch & batch);
121
 
122
+ // TODO: maybe not needed
123
+ uint32_t get_padding(const llama_cparams & cparams) const;
 
 
124
 
125
+ // find how many cells are currently in use
126
+ uint32_t cell_max() const;
127
 
128
+ size_t size_k_bytes() const;
129
+ size_t size_v_bytes() const;
130
 
131
+ // defrag
 
132
 
133
+ struct {
134
+ std::vector<uint32_t> ids;
135
+ } defrag_info;
 
 
 
 
 
136
 
137
+ // return true if cells have been moved
138
+ bool defrag_prepare(int32_t n_max_nodes);
 
 
 
 
 
 
139
 
140
+ // commit/restore cache
 
141
 
142
+ struct slot_range {
143
+ uint32_t c0 = 0; // note: these are cell indices, not sequence positions
144
+ uint32_t c1 = 0;
145
+ };
146
 
147
+ // pending cell updates that are not yet committed
148
+ struct {
149
+ std::vector<slot_range> ranges;
150
+ } pending;
 
151
 
152
+ // state write/load
 
 
 
 
 
153
 
154
+ void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
155
+ void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
 
156
 
157
+ // members
 
 
 
 
 
158
 
159
+ const llama_hparams & hparams;
 
 
 
 
 
160
 
161
+ callbacks cbs;
 
 
162
 
163
+ bool has_shift = false;
164
+ bool do_defrag = false;
165
 
166
+ // TODO: remove this and implement llama_kv_cache_recurrent instead
167
+ bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
168
 
169
+ bool v_trans = true; // the value tensor is transposed
170
+ bool can_shift = false;
171
 
172
+ // Note: The value of head isn't only used to optimize searching
173
+ // for a free KV slot. llama_decode_impl also uses it, so it
174
+ // cannot be freely changed after a slot has been allocated.
175
+ uint32_t head = 0;
176
+ uint32_t size = 0;
177
+ uint32_t used = 0; // used cells (i.e. at least one seq_id)
178
 
179
+ // computed before each graph build
180
+ uint32_t n = 0;
 
181
 
182
+ std::vector<llama_kv_cell> cells;
183
 
184
+ std::vector<ggml_tensor *> k_l; // per layer
185
+ std::vector<ggml_tensor *> v_l;
186
 
187
+ private:
188
+ ggml_type type_k = GGML_TYPE_F16;
189
+ ggml_type type_v = GGML_TYPE_F16;
190
 
191
+ std::vector<ggml_context_ptr> ctxs;
192
+ std::vector<ggml_backend_buffer_ptr> bufs;
 
 
 
 
 
193
 
194
+ void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
195
+ void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
 
196
 
197
+ bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
198
+ bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
199
+ };
200
 
201
+ // TODO: temporarily reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
202
+ //class llama_kv_cache_recurrent : public llama_kv_cache_unified {
203
+ //public:
204
+ // using llama_kv_cache_unified::llama_kv_cache_unified;
205
+ //};
206
 
207
+ //
208
+ // kv cache view
209
+ //
 
 
 
 
 
 
210
 
211
+ llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
+ void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv);
examples/talk-llama/llama-memory.cpp ADDED
@@ -0,0 +1 @@
 
 
1
+ #include "llama-memory.h"
examples/talk-llama/llama-memory.h ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include "llama.h"
4
+
5
+ // general concept of LLM memory
6
+ // the KV cache is a type of LLM memory, but there can be other types
7
+ class llama_memory_i {
8
+ public:
9
+ virtual void clear() = 0;
10
+ virtual void defrag() = 0;
11
+
12
+ virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
13
+ virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
14
+ virtual void seq_keep(llama_seq_id seq_id) = 0;
15
+ virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
16
+ virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
17
+
18
+ virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
19
+
20
+ virtual bool get_can_edit() const = 0;
21
+ };
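For reference, a do-nothing implementation showing which members a new memory type must override; the class name is invented, only the signatures come from the interface above.

#include "llama-memory.h"

class llama_null_memory : public llama_memory_i {
public:
    void clear()  override {}
    void defrag() override {}

    bool seq_rm  (llama_seq_id, llama_pos, llama_pos)               override { return true; }
    void seq_cp  (llama_seq_id, llama_seq_id, llama_pos, llama_pos) override {}
    void seq_keep(llama_seq_id)                                     override {}
    void seq_add (llama_seq_id, llama_pos, llama_pos, llama_pos)    override {}
    void seq_div (llama_seq_id, llama_pos, llama_pos, int)          override {}

    llama_pos seq_pos_max(llama_seq_id) const override { return -1; } // no positions stored

    bool get_can_edit() const override { return false; }
};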
examples/talk-llama/llama-mmap.cpp CHANGED
@@ -8,6 +8,7 @@
8
  #include <climits>
9
  #include <stdexcept>
10
  #include <cerrno>
 
11
 
12
  #ifdef __has_include
13
  #if __has_include(<unistd.h>)
@@ -34,6 +35,10 @@
34
  #include <io.h>
35
  #endif
36
 
 
 
 
 
37
  // TODO: consider moving to llama-impl.h if needed in more places
38
  #if defined(_WIN32)
39
  static std::string llama_format_win_err(DWORD err) {
@@ -471,7 +476,11 @@ struct llama_mlock::impl {
471
 
472
  char* errmsg = std::strerror(errno);
473
  bool suggest = (errno == ENOMEM);
474
-
 
 
 
 
475
  struct rlimit lock_limit;
476
  if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) {
477
  suggest = false;
@@ -479,6 +488,7 @@ struct llama_mlock::impl {
479
  if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) {
480
  suggest = false;
481
  }
 
482
 
483
  LLAMA_LOG_WARN("warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
484
  size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
 
8
  #include <climits>
9
  #include <stdexcept>
10
  #include <cerrno>
11
+ #include <algorithm>
12
 
13
  #ifdef __has_include
14
  #if __has_include(<unistd.h>)
 
35
  #include <io.h>
36
  #endif
37
 
38
+ #if defined(__APPLE__)
39
+ #include <TargetConditionals.h>
40
+ #endif
41
+
42
  // TODO: consider moving to llama-impl.h if needed in more places
43
  #if defined(_WIN32)
44
  static std::string llama_format_win_err(DWORD err) {
 
476
 
477
  char* errmsg = std::strerror(errno);
478
  bool suggest = (errno == ENOMEM);
479
+ #if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX)
480
+ // visionOS/tvOS don't support RLIMIT_MEMLOCK
481
+ // Skip resource limit checks on visionOS/tvOS
482
+ suggest = false;
483
+ #else
484
  struct rlimit lock_limit;
485
  if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) {
486
  suggest = false;
 
488
  if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) {
489
  suggest = false;
490
  }
491
+ #endif
492
 
493
  LLAMA_LOG_WARN("warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
494
  size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
examples/talk-llama/llama-mmap.h CHANGED
@@ -1,5 +1,6 @@
1
  #pragma once
2
 
 
3
  #include <memory>
4
  #include <vector>
5
 
 
1
  #pragma once
2
 
3
+ #include <cstdint>
4
  #include <memory>
5
  #include <vector>
6
 
examples/talk-llama/llama-model-loader.cpp CHANGED
@@ -445,7 +445,8 @@ llama_model_loader::llama_model_loader(
445
  std::vector<std::string> & splits,
446
  bool use_mmap,
447
  bool check_tensors,
448
- const struct llama_model_kv_override * param_overrides_p) {
 
449
  int trace = 0;
450
  if (getenv("LLAMA_TRACE")) {
451
  trace = atoi(getenv("LLAMA_TRACE"));
@@ -457,6 +458,8 @@ llama_model_loader::llama_model_loader(
457
  }
458
  }
459
 
 
 
460
  // Load the main GGUF
461
  struct ggml_context * ctx = NULL;
462
  struct gguf_init_params params = {
@@ -600,7 +603,9 @@ llama_model_loader::llama_model_loader(
600
 
601
  if (trace > 0) {
602
  const uint16_t sid = w.idx;
603
- LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ]\n", __func__, sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str());
 
 
604
  }
605
  }
606
 
@@ -640,9 +645,9 @@ llama_model_loader::llama_model_loader(
640
  ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED);
641
 
642
  {
643
- const int kid = gguf_find_key(meta.get(), "general.file_type"); // TODO: use LLM_KV
644
- if (kid >= 0) {
645
- ftype = (llama_ftype) gguf_get_val_u32(meta.get(), kid);
646
  }
647
  }
648
 
 
445
  std::vector<std::string> & splits,
446
  bool use_mmap,
447
  bool check_tensors,
448
+ const llama_model_kv_override * param_overrides_p,
449
+ const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
450
  int trace = 0;
451
  if (getenv("LLAMA_TRACE")) {
452
  trace = atoi(getenv("LLAMA_TRACE"));
 
458
  }
459
  }
460
 
461
+ tensor_buft_overrides = param_tensor_buft_overrides_p;
462
+
463
  // Load the main GGUF
464
  struct ggml_context * ctx = NULL;
465
  struct gguf_init_params params = {
 
603
 
604
  if (trace > 0) {
605
  const uint16_t sid = w.idx;
606
+ LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ] %8.2f MiB\n", __func__,
607
+ sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str(),
608
+ ggml_nbytes(tensor)/1024.0f/1024.0f);
609
  }
610
  }
611
 
 
645
  ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED);
646
 
647
  {
648
+ uint32_t ftype_val = 0;
649
+ if (get_key(LLM_KV_GENERAL_FILE_TYPE, ftype_val, false)) {
650
+ ftype = (llama_ftype) ftype_val;
651
  }
652
  }
653
 
examples/talk-llama/llama-model-loader.h CHANGED
@@ -77,8 +77,9 @@ struct llama_model_loader {
77
 
78
  llama_mmaps mappings;
79
 
80
- std::map<std::string, struct llama_tensor_weight, weight_name_comparer> weights_map;
81
- std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
 
82
 
83
  gguf_context_ptr meta;
84
  std::vector<ggml_context_ptr> contexts;
@@ -95,7 +96,8 @@ struct llama_model_loader {
95
  std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
96
  bool use_mmap,
97
  bool check_tensors,
98
- const struct llama_model_kv_override * param_overrides_p);
 
99
 
100
  template<typename T>
101
  typename std::enable_if<std::is_integral<T>::value, bool>::type
 
77
 
78
  llama_mmaps mappings;
79
 
80
+ std::map<std::string, llama_tensor_weight, weight_name_comparer> weights_map;
81
+ std::unordered_map<std::string, llama_model_kv_override> kv_overrides;
82
+ const llama_model_tensor_buft_override * tensor_buft_overrides;
83
 
84
  gguf_context_ptr meta;
85
  std::vector<ggml_context_ptr> contexts;
 
96
  std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
97
  bool use_mmap,
98
  bool check_tensors,
99
+ const llama_model_kv_override * param_overrides_p,
100
+ const llama_model_tensor_buft_override * param_tensor_buft_overrides_p);
101
 
102
  template<typename T>
103
  typename std::enable_if<std::is_integral<T>::value, bool>::type
examples/talk-llama/llama-model.cpp CHANGED
The diff for this file is too large to render. See raw diff
 
examples/talk-llama/llama-model.h CHANGED
@@ -2,7 +2,9 @@
2
 
3
  #include "llama.h"
4
  #include "llama-arch.h"
 
5
  #include "llama-hparams.h"
 
6
  #include "llama-vocab.h"
7
 
8
  #include <memory>
@@ -10,6 +12,8 @@
10
  #include <unordered_map>
11
  #include <vector>
12
 
 
 
13
  struct llama_model_loader;
14
 
15
  // available models
@@ -25,6 +29,7 @@ enum llm_type {
25
  LLM_TYPE_109M,
26
  LLM_TYPE_137M,
27
  LLM_TYPE_160M,
 
28
  LLM_TYPE_220M,
29
  LLM_TYPE_250M,
30
  LLM_TYPE_270M,
@@ -39,8 +44,10 @@ enum llm_type {
39
  LLM_TYPE_1_4B,
40
  LLM_TYPE_1_5B,
41
  LLM_TYPE_1_6B,
 
42
  LLM_TYPE_2B,
43
  LLM_TYPE_2_8B,
 
44
  LLM_TYPE_3B,
45
  LLM_TYPE_4B,
46
  LLM_TYPE_6B,
@@ -78,6 +85,9 @@ enum llm_type {
78
  LLM_TYPE_10B_128x3_66B,
79
  LLM_TYPE_57B_A14B,
80
  LLM_TYPE_27B,
 
 
 
81
  };
82
 
83
  struct llama_layer_posnet {
@@ -161,6 +171,8 @@ struct llama_layer {
161
  struct ggml_tensor * wq_b = nullptr;
162
  struct ggml_tensor * wkv_a_mqa = nullptr;
163
  struct ggml_tensor * wkv_b = nullptr;
 
 
164
  struct ggml_tensor * wq_cross = nullptr;
165
  struct ggml_tensor * wk_cross = nullptr;
166
  struct ggml_tensor * wv_cross = nullptr;
@@ -256,6 +268,20 @@ struct llama_layer {
256
  struct ggml_tensor * time_mix_receptance_b = nullptr;
257
  struct ggml_tensor * time_mix_gate = nullptr;
258
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  struct ggml_tensor * time_mix_ln = nullptr;
260
  struct ggml_tensor * time_mix_ln_b = nullptr;
261
  struct ggml_tensor * time_mix_output = nullptr;
@@ -347,7 +373,7 @@ struct llama_model {
347
  std::string desc() const;
348
 
349
  size_t size() const;
350
- size_t max_nodes() const;
351
  size_t n_devices() const;
352
 
353
  // total number of parameters in the model
@@ -360,11 +386,26 @@ struct llama_model {
360
 
361
  ggml_backend_buffer_type_t select_buft(int il) const;
362
 
 
 
363
  const struct ggml_tensor * get_tensor(const char * name) const;
364
 
 
 
 
 
 
 
 
 
 
365
  private:
366
  struct impl;
367
  std::unique_ptr<impl> pimpl;
368
  };
369
 
370
  const char * llm_type_name(llm_type type);
 
 
 
 
 
2
 
3
  #include "llama.h"
4
  #include "llama-arch.h"
5
+ #include "llama-graph.h"
6
  #include "llama-hparams.h"
7
+ #include "llama-memory.h"
8
  #include "llama-vocab.h"
9
 
10
  #include <memory>
 
12
  #include <unordered_map>
13
  #include <vector>
14
 
15
+ struct llama_cparams;
16
+ struct llama_ubatch;
17
  struct llama_model_loader;
18
 
19
  // available models
 
29
  LLM_TYPE_109M,
30
  LLM_TYPE_137M,
31
  LLM_TYPE_160M,
32
+ LLM_TYPE_190M,
33
  LLM_TYPE_220M,
34
  LLM_TYPE_250M,
35
  LLM_TYPE_270M,
 
44
  LLM_TYPE_1_4B,
45
  LLM_TYPE_1_5B,
46
  LLM_TYPE_1_6B,
47
+ LLM_TYPE_1_8B,
48
  LLM_TYPE_2B,
49
  LLM_TYPE_2_8B,
50
+ LLM_TYPE_2_9B,
51
  LLM_TYPE_3B,
52
  LLM_TYPE_4B,
53
  LLM_TYPE_6B,
 
85
  LLM_TYPE_10B_128x3_66B,
86
  LLM_TYPE_57B_A14B,
87
  LLM_TYPE_27B,
88
+ LLM_TYPE_290B,
89
+ LLM_TYPE_17B_16E, // llama4 Scout
90
+ LLM_TYPE_17B_128E, // llama4 Maverick
91
  };
92
 
93
  struct llama_layer_posnet {
 
171
  struct ggml_tensor * wq_b = nullptr;
172
  struct ggml_tensor * wkv_a_mqa = nullptr;
173
  struct ggml_tensor * wkv_b = nullptr;
174
+ struct ggml_tensor * wk_b = nullptr;
175
+ struct ggml_tensor * wv_b = nullptr;
176
  struct ggml_tensor * wq_cross = nullptr;
177
  struct ggml_tensor * wk_cross = nullptr;
178
  struct ggml_tensor * wv_cross = nullptr;
 
268
  struct ggml_tensor * time_mix_receptance_b = nullptr;
269
  struct ggml_tensor * time_mix_gate = nullptr;
270
 
271
+ // rwkv7
272
+ struct ggml_tensor * time_mix_w0 = nullptr;
273
+ struct ggml_tensor * time_mix_a0 = nullptr;
274
+ struct ggml_tensor * time_mix_a1 = nullptr;
275
+ struct ggml_tensor * time_mix_a2 = nullptr;
276
+ struct ggml_tensor * time_mix_v0 = nullptr;
277
+ struct ggml_tensor * time_mix_v1 = nullptr;
278
+ struct ggml_tensor * time_mix_v2 = nullptr;
279
+ struct ggml_tensor * time_mix_g1 = nullptr;
280
+ struct ggml_tensor * time_mix_g2 = nullptr;
281
+ struct ggml_tensor * time_mix_k_k = nullptr;
282
+ struct ggml_tensor * time_mix_k_a = nullptr;
283
+ struct ggml_tensor * time_mix_r_k = nullptr;
284
+
285
  struct ggml_tensor * time_mix_ln = nullptr;
286
  struct ggml_tensor * time_mix_ln_b = nullptr;
287
  struct ggml_tensor * time_mix_output = nullptr;
 
373
  std::string desc() const;
374
 
375
  size_t size() const;
376
+ size_t n_tensors() const;
377
  size_t n_devices() const;
378
 
379
  // total number of parameters in the model
 
386
 
387
  ggml_backend_buffer_type_t select_buft(int il) const;
388
 
389
+ bool has_tensor_overrides() const;
390
+
391
  const struct ggml_tensor * get_tensor(const char * name) const;
392
 
393
+ // TODO: move this to new llm_arch_model_i interface
394
+ llama_memory_i * create_memory() const; // TODO: params
395
+
396
+ // TODO: move this to new llm_arch_model_i interface
397
+ llm_graph_result_ptr build_graph(
398
+ const llm_graph_params & params,
399
+ ggml_cgraph * gf,
400
+ llm_graph_type type) const;
401
+
402
  private:
403
  struct impl;
404
  std::unique_ptr<impl> pimpl;
405
  };
406
 
407
  const char * llm_type_name(llm_type type);
408
+
409
+ // For internal test use
410
+ // TODO: remove
411
+ const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model);
examples/talk-llama/llama-quant.cpp CHANGED
@@ -10,6 +10,7 @@
10
  #include <cinttypes>
11
  #include <fstream>
12
  #include <mutex>
 
13
  #include <thread>
14
  #include <unordered_map>
15
 
@@ -47,8 +48,14 @@ struct quantize_state_impl {
47
  {}
48
  };
49
 
 
 
 
 
 
 
50
  static void llama_tensor_dequantize_impl(
51
- struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
52
  const size_t nelements, const int nthread
53
  ) {
54
  if (output.size() < nelements) {
@@ -527,7 +534,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
527
  }
528
 
529
  std::vector<std::string> splits = {};
530
- llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides);
531
  ml.init_mappings(false); // no prefetching
532
 
533
  llama_model model(llama_model_default_params());
@@ -536,7 +543,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
536
  model.load_hparams(ml);
537
  model.load_stats (ml);
538
 
539
- struct quantize_state_impl qs(model, params);
540
 
541
  if (params->only_copy) {
542
  ftype = ml.ftype;
@@ -661,7 +668,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
661
  // populate the original tensors so we get an initial meta data
662
  for (const auto * it : tensors) {
663
  uint16_t i_split = params->keep_split ? it->idx : 0;
664
- struct ggml_tensor * tensor = it->tensor;
665
  if (!ctx_outs[i_split]) {
666
  ctx_outs[i_split].reset(gguf_init_empty());
667
  }
@@ -710,7 +717,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
710
  new_ofstream(0);
711
  for (const auto * it : tensors) {
712
  const auto & weight = *it;
713
- struct ggml_tensor * tensor = weight.tensor;
714
  if (weight.idx != cur_split && params->keep_split) {
715
  close_ofstream();
716
  new_ofstream(weight.idx);
@@ -756,10 +763,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
756
  // NOTE: can't use LLM_TN here because the layer number is not known
757
  quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
758
 
759
- // do not quantize RWKV's time_mix_first tensors
760
  quantize &= name.find("time_mix_first.weight") == std::string::npos;
 
761
  quantize &= name.find("time_mix_w1.weight") == std::string::npos;
762
  quantize &= name.find("time_mix_w2.weight") == std::string::npos;
 
 
 
 
 
 
 
 
763
  quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
764
  quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
765
  quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
@@ -767,7 +783,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
767
  // do not quantize relative position bias (T5)
768
  quantize &= name.find("attn_rel_b.weight") == std::string::npos;
769
 
770
- enum ggml_type new_type;
771
  void * new_data;
772
  size_t new_size;
773
 
@@ -777,6 +793,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
777
  // get more optimal quantization type based on the tensor shape, layer, etc.
778
  if (!params->pure && ggml_is_quantized(default_type)) {
779
  new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
 
 
 
 
 
 
 
 
 
 
 
 
 
780
  }
781
  if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
782
  new_type = params->token_embedding_type;
@@ -901,8 +930,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
901
  // interface implementation
902
  //
903
 
904
- struct llama_model_quantize_params llama_model_quantize_default_params() {
905
- struct llama_model_quantize_params result = {
906
  /*.nthread =*/ 0,
907
  /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
908
  /*.output_tensor_type =*/ GGML_TYPE_COUNT,
@@ -914,6 +943,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
914
  /*.keep_split =*/ false,
915
  /*.imatrix =*/ nullptr,
916
  /*.kv_overrides =*/ nullptr,
 
917
  };
918
 
919
  return result;
 
10
  #include <cinttypes>
11
  #include <fstream>
12
  #include <mutex>
13
+ #include <regex>
14
  #include <thread>
15
  #include <unordered_map>
16
 
 
48
  {}
49
  };
50
 
51
+ // changes to this struct must be replicated in quantize.cpp
52
+ struct tensor_quantization {
53
+ std::string name;
54
+ ggml_type quant = GGML_TYPE_COUNT;
55
+ };
56
+
57
  static void llama_tensor_dequantize_impl(
58
+ ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
59
  const size_t nelements, const int nthread
60
  ) {
61
  if (output.size() < nelements) {
 
534
  }
535
 
536
  std::vector<std::string> splits = {};
537
+ llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides, nullptr);
538
  ml.init_mappings(false); // no prefetching
539
 
540
  llama_model model(llama_model_default_params());
 
543
  model.load_hparams(ml);
544
  model.load_stats (ml);
545
 
546
+ quantize_state_impl qs(model, params);
547
 
548
  if (params->only_copy) {
549
  ftype = ml.ftype;
 
668
  // populate the original tensors so we get an initial meta data
669
  for (const auto * it : tensors) {
670
  uint16_t i_split = params->keep_split ? it->idx : 0;
671
+ ggml_tensor * tensor = it->tensor;
672
  if (!ctx_outs[i_split]) {
673
  ctx_outs[i_split].reset(gguf_init_empty());
674
  }
 
717
  new_ofstream(0);
718
  for (const auto * it : tensors) {
719
  const auto & weight = *it;
720
+ ggml_tensor * tensor = weight.tensor;
721
  if (weight.idx != cur_split && params->keep_split) {
722
  close_ofstream();
723
  new_ofstream(weight.idx);
 
763
  // NOTE: can't use LLM_TN here because the layer number is not known
764
  quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
765
 
766
+ // do not quantize RWKV's small yet 2D weights
767
  quantize &= name.find("time_mix_first.weight") == std::string::npos;
768
+ quantize &= name.find("time_mix_w0.weight") == std::string::npos;
769
  quantize &= name.find("time_mix_w1.weight") == std::string::npos;
770
  quantize &= name.find("time_mix_w2.weight") == std::string::npos;
771
+ quantize &= name.find("time_mix_v0.weight") == std::string::npos;
772
+ quantize &= name.find("time_mix_v1.weight") == std::string::npos;
773
+ quantize &= name.find("time_mix_v2.weight") == std::string::npos;
774
+ quantize &= name.find("time_mix_a0.weight") == std::string::npos;
775
+ quantize &= name.find("time_mix_a1.weight") == std::string::npos;
776
+ quantize &= name.find("time_mix_a2.weight") == std::string::npos;
777
+ quantize &= name.find("time_mix_g1.weight") == std::string::npos;
778
+ quantize &= name.find("time_mix_g2.weight") == std::string::npos;
779
  quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
780
  quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
781
  quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
 
783
  // do not quantize relative position bias (T5)
784
  quantize &= name.find("attn_rel_b.weight") == std::string::npos;
785
 
786
+ ggml_type new_type;
787
  void * new_data;
788
  size_t new_size;
789
 
 
793
  // get more optimal quantization type based on the tensor shape, layer, etc.
794
  if (!params->pure && ggml_is_quantized(default_type)) {
795
  new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
796
+ // unless the user specifies a type
797
+ if (params->tensor_types) {
798
+ const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
799
+ for (const auto & [tname, qtype] : tensor_types) {
800
+ if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) {
801
+ if (qtype != new_type) {
802
+ LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype));
803
+ }
804
+ new_type = qtype;
805
+ break;
806
+ }
807
+ }
808
+ }
809
  }
810
  if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
811
  new_type = params->token_embedding_type;
 
930
  // interface implementation
931
  //
932
 
933
+ llama_model_quantize_params llama_model_quantize_default_params() {
934
+ llama_model_quantize_params result = {
935
  /*.nthread =*/ 0,
936
  /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
937
  /*.output_tensor_type =*/ GGML_TYPE_COUNT,
 
943
  /*.keep_split =*/ false,
944
  /*.imatrix =*/ nullptr,
945
  /*.kv_overrides =*/ nullptr,
946
+ /*.tensor_types =*/ nullptr,
947
  };
948
 
949
  return result;
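A hedged sketch of how a caller could feed per-tensor overrides through `params.tensor_types`, matching the struct layout the override loop above casts to; `llama_model_quantize` and `llama_model_quantize_default_params` are the public entry points from llama.h, and the regex patterns and target types here are arbitrary examples.

#include "llama.h"

#include <string>
#include <vector>

// must mirror the tensor_quantization struct in llama-quant.cpp
struct tensor_quantization {
    std::string name;                  // regex matched against tensor->name
    ggml_type   quant = GGML_TYPE_COUNT;
};

static void quantize_with_overrides(const char * fname_inp, const char * fname_out) {
    std::vector<tensor_quantization> overrides = {
        { "attn_v\\.weight", GGML_TYPE_Q6_K }, // example pattern and type
        { "ffn_down",        GGML_TYPE_Q5_K },
    };

    llama_model_quantize_params params = llama_model_quantize_default_params();
    params.tensor_types = &overrides; // consumed by the override loop in llama-quant.cpp

    llama_model_quantize(fname_inp, fname_out, &params);
}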
examples/talk-llama/llama-sampling.cpp CHANGED
@@ -316,6 +316,13 @@ static uint32_t get_rng_seed(uint32_t seed) {
316
 
317
  // llama_sampler API
318
 
 
 
 
 
 
 
 
319
  const char * llama_sampler_name(const struct llama_sampler * smpl) {
320
  if (!smpl->iface) {
321
  return "(null)";
@@ -347,10 +354,10 @@ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
347
  }
348
 
349
  if (smpl->ctx == nullptr) {
350
- return new llama_sampler {
351
  /* .iface = */ smpl->iface,
352
- /* .ctx = */ nullptr,
353
- };
354
  }
355
 
356
  GGML_ABORT("the sampler does not support cloning");
@@ -472,15 +479,15 @@ static struct llama_sampler_i llama_sampler_chain_i = {
472
  };
473
 
474
  struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
475
- return new llama_sampler {
476
  /* .iface = */ &llama_sampler_chain_i,
477
  /* .ctx = */ new llama_sampler_chain {
478
  /* .params = */ params,
479
  /* .samplers = */ {},
480
  /* .t_sample_us = */ 0,
481
  /* .n_sample = */ 0,
482
- },
483
- };
484
  }
485
 
486
  void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
@@ -546,10 +553,10 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
546
  };
547
 
548
  struct llama_sampler * llama_sampler_init_greedy() {
549
- return new llama_sampler {
550
  /* .iface = */ &llama_sampler_greedy_i,
551
- /* .ctx = */ nullptr,
552
- };
553
  }
554
 
555
  // dist
@@ -608,14 +615,14 @@ static struct llama_sampler_i llama_sampler_dist_i = {
608
 
609
  struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
610
  auto seed_cur = get_rng_seed(seed);
611
- return new llama_sampler {
612
  /* .iface = */ &llama_sampler_dist_i,
613
  /* .ctx = */ new llama_sampler_dist {
614
  /* .seed = */ seed,
615
  /* .seed_cur = */ seed_cur,
616
  /* .rng = */ std::mt19937(seed_cur),
617
- },
618
- };
619
  }
620
 
621
  // softmax
@@ -638,10 +645,10 @@ static struct llama_sampler_i llama_sampler_softmax_i = {
638
  };
639
 
640
  struct llama_sampler * llama_sampler_init_softmax() {
641
- return new llama_sampler {
642
  /* .iface = */ &llama_sampler_softmax_i,
643
- /* .ctx = */ nullptr,
644
- };
645
  }
646
 
647
  // top-k
@@ -678,12 +685,12 @@ static struct llama_sampler_i llama_sampler_top_k_i = {
678
  };
679
 
680
  struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
681
- return new llama_sampler {
682
  /* .iface = */ &llama_sampler_top_k_i,
683
  /* .ctx = */ new llama_sampler_top_k {
684
  /* .k = */ k,
685
- },
686
- };
687
  }
688
 
689
  // top-p
@@ -744,13 +751,13 @@ static struct llama_sampler_i llama_sampler_top_p_i = {
744
  };
745
 
746
  struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
747
- return new llama_sampler {
748
  /* .iface = */ &llama_sampler_top_p_i,
749
  /* .ctx = */ new llama_sampler_top_p {
750
  /* .p = */ p,
751
  /* .min_keep = */ min_keep,
752
- },
753
- };
754
  }
755
 
756
  // min-p
@@ -840,13 +847,13 @@ static struct llama_sampler_i llama_sampler_min_p_i = {
840
  };
841
 
842
  struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
843
- return new llama_sampler {
844
  /* .iface = */ &llama_sampler_min_p_i,
845
  /* .ctx = */ new llama_sampler_min_p {
846
  /* .p = */ p,
847
  /* .min_keep = */ min_keep,
848
- },
849
- };
850
  }
851
 
852
  // typical
@@ -939,13 +946,13 @@ static struct llama_sampler_i llama_sampler_typical_i = {
939
  };
940
 
941
  struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
942
- return new llama_sampler {
943
  /* .iface = */ &llama_sampler_typical_i,
944
  /* .ctx = */ new llama_sampler_typical {
945
  /* .p = */ p,
946
  /* .min_keep = */ min_keep,
947
- },
948
- };
949
  }
950
 
951
  // temp
@@ -983,12 +990,12 @@ static struct llama_sampler_i llama_sampler_temp_i = {
983
  };
984
 
985
  struct llama_sampler * llama_sampler_init_temp(float temp) {
986
- return new llama_sampler {
987
  /* .iface = */ &llama_sampler_temp_i,
988
  /* .ctx = */ new llama_sampler_temp {
989
  /*.temp = */ temp,
990
- },
991
- };
992
  }
993
 
994
  // temp-ext
@@ -1093,14 +1100,14 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
1093
  };
1094
 
1095
  struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
1096
- return new llama_sampler {
1097
  /* .iface = */ &llama_sampler_temp_ext_i,
1098
  /* .ctx = */ new llama_sampler_temp_ext {
1099
  /* .temp = */ temp,
1100
  /* .delta = */ delta,
1101
  /* .exponent = */ exponent,
1102
- },
1103
- };
1104
  }
1105
 
1106
  // xtc
@@ -1185,7 +1192,7 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
1185
 
1186
  struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
1187
  auto seed_cur = get_rng_seed(seed);
1188
- return new llama_sampler {
1189
  /* .iface = */ &llama_sampler_xtc_i,
1190
  /* .ctx = */ new llama_sampler_xtc {
1191
  /* .probability = */ p,
@@ -1194,8 +1201,8 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
1194
  /* .seed = */ seed,
1195
  /* .seed_cur = */ seed_cur,
1196
  /* .rng = */ std::mt19937(seed_cur),
1197
- },
1198
- };
1199
  }
1200
 
1201
  // mirostat
@@ -1292,7 +1299,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
1292
 
1293
  struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
1294
  auto seed_cur = get_rng_seed(seed);
1295
- return new llama_sampler {
1296
  /* .iface = */ &llama_sampler_mirostat_i,
1297
  /* .ctx = */ new llama_sampler_mirostat {
1298
  /* .n_vocab = */ n_vocab,
@@ -1303,8 +1310,8 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
1303
  /* .m = */ m,
1304
  /* .mu = */ 2.0f*tau,
1305
  /* .rng = */ std::mt19937(seed_cur),
1306
- },
1307
- };
1308
  }
1309
 
1310
  // mirostat v2
@@ -1391,7 +1398,7 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
1391
 
1392
  struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
1393
  auto seed_cur = get_rng_seed(seed);
1394
- return new llama_sampler {
1395
  /* .iface = */ &llama_sampler_mirostat_v2_i,
1396
  /* .ctx = */ new llama_sampler_mirostat_v2 {
1397
  /* .seed = */ seed,
@@ -1400,8 +1407,8 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau,
1400
  /* .eta = */ eta,
1401
  /* .mu = */ 2.0f*tau,
1402
  /* .rng = */ std::mt19937(seed_cur),
1403
- },
1404
- };
1405
  }
1406
 
1407
  // grammar
@@ -1442,7 +1449,9 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
1442
  const char ** trigger_words,
1443
  size_t num_trigger_words,
1444
  const llama_token * trigger_tokens,
1445
- size_t num_trigger_tokens);
1446
 
1447
  static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
1448
  auto * ctx = (llama_sampler_grammar *) smpl->ctx;
@@ -1450,12 +1459,14 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
1450
  return;
1451
  }
1452
 
1453
- std::vector<const char *> trigger_words;
1454
- for (auto & word : ctx->grammar->trigger_words) {
1455
- trigger_words.push_back(word.c_str());
 
1456
  }
 
1457
  auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
1458
- ctx->grammar->lazy, trigger_words.data(), trigger_words.size(),
1459
  ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
1460
 
1461
  llama_grammar_free_impl(ctx->grammar);
@@ -1465,7 +1476,8 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
1465
  static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
1466
  const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
1467
 
1468
- auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0);
 
1469
 
1470
  // copy the state
1471
  {
@@ -1509,16 +1521,38 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
1509
  const char ** trigger_words,
1510
  size_t num_trigger_words,
1511
  const llama_token * trigger_tokens,
1512
- size_t num_trigger_tokens) {
1513
  auto * ctx = new llama_sampler_grammar;
1514
 
1515
  if (grammar_str != nullptr && grammar_str[0] != '\0') {
1516
  *ctx = {
1517
  /* .vocab = */ vocab,
1518
  /* .grammar_str = */ grammar_str,
1519
  /* .grammar_root = */ grammar_root,
1520
- /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens),
1521
  };
1522
  } else {
1523
  *ctx = {
1524
  /* .vocab = */ vocab,
@@ -1528,17 +1562,17 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
1528
  };
1529
  }
1530
 
1531
- return new llama_sampler {
1532
  /* .iface = */ &llama_sampler_grammar_i,
1533
- /* .ctx = */ ctx,
1534
- };
1535
  }
1536
 
1537
  struct llama_sampler * llama_sampler_init_grammar(
1538
  const struct llama_vocab * vocab,
1539
  const char * grammar_str,
1540
  const char * grammar_root) {
1541
- return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0);
1542
  }
1543
 
1544
  struct llama_sampler * llama_sampler_init_grammar_lazy(
@@ -1549,7 +1583,18 @@ struct llama_sampler * llama_sampler_init_grammar_lazy(
1549
  size_t num_trigger_words,
1550
  const llama_token * trigger_tokens,
1551
  size_t num_trigger_tokens) {
1552
- return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens);
1553
  }
1554
 
1555
  // penalties
@@ -1678,7 +1723,7 @@ struct llama_sampler * llama_sampler_init_penalties(
1678
  float penalty_present) {
1679
  penalty_last_n = std::max(penalty_last_n, 0);
1680
 
1681
- return new llama_sampler {
1682
  /* .iface = */ &llama_sampler_penalties_i,
1683
  /* .ctx = */ new llama_sampler_penalties {
1684
  /* .penalty_last_n = */ penalty_last_n,
@@ -1687,8 +1732,75 @@ struct llama_sampler * llama_sampler_init_penalties(
1687
  /* .penalty_present = */ penalty_present,
1688
  /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
1689
  /* .token_count = */ {},
1690
- },
1691
- };
1692
  }
1693
 
1694
  // DRY
@@ -2041,7 +2153,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
2041
  }
2042
  }
2043
 
2044
- return new llama_sampler {
2045
  /* .iface = */ &llama_sampler_dry_i,
2046
  /* .ctx = */ new llama_sampler_dry {
2047
  /* .total_context_size = */ context_size,
@@ -2053,8 +2165,8 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
2053
  /* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
2054
  /* .dry_max_token_repeat = */ {},
2055
  /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
2056
- },
2057
- };
2058
  }
2059
 
2060
  // wrapper for test-sampling.cpp
@@ -2155,14 +2267,14 @@ struct llama_sampler * llama_sampler_init_logit_bias(
2155
  int32_t n_vocab,
2156
  int32_t n_logit_bias,
2157
  const llama_logit_bias * logit_bias) {
2158
- return new llama_sampler {
2159
  /* .iface = */ &llama_sampler_logit_bias_i,
2160
  /* .ctx = */ new llama_sampler_logit_bias {
2161
  /* .n_vocab = */ n_vocab,
2162
  /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
2163
  /* .to_search = */ {},
2164
- },
2165
- };
2166
  }
2167
 
2168
  // infill
@@ -2377,14 +2489,14 @@ static struct llama_sampler_i llama_sampler_infill_i = {
2377
  };
2378
 
2379
  struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
2380
- return new llama_sampler {
2381
  /* .iface = */ &llama_sampler_infill_i,
2382
  /* .ctx = */ new llama_sampler_infill {
2383
  /* .vocab = */ vocab,
2384
  /* .buf0 = */ std::vector<char>(512),
2385
  /* .buf1 = */ std::vector<char>(512),
2386
- },
2387
- };
2388
  }
2389
 
2390
  // utils
 
316
 
317
  // llama_sampler API
318
 
319
+ struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
320
+ return new llama_sampler {
321
+ /* .iface = */ iface,
322
+ /* .ctx = */ ctx,
323
+ };
324
+ }
325
+
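Note on the pattern above: this sync swaps every direct "new llama_sampler { ... }" allocation in llama-sampling.cpp for the newly exported llama_sampler_init(iface, ctx) constructor, so the same entry point is available to user code that registers custom samplers. A minimal sketch under that assumption (the my_noop_* names are hypothetical; llama_sampler_init, llama_sampler_i and llama_token_data_array are the types used throughout this diff):

// hypothetical pass-through sampler built on the exported constructor
static const char * my_noop_name (const struct llama_sampler * /*smpl*/) { return "noop"; }
static void         my_noop_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * /*cur_p*/) { /* leave the candidates untouched */ }

static struct llama_sampler_i my_noop_i = {
    /* .name   = */ my_noop_name,
    /* .accept = */ nullptr,
    /* .apply  = */ my_noop_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ nullptr,
    /* .free   = */ nullptr, // no ctx to release
};

struct llama_sampler * my_noop_init(void) {
    return llama_sampler_init(&my_noop_i, /* ctx = */ nullptr);
}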
326
  const char * llama_sampler_name(const struct llama_sampler * smpl) {
327
  if (!smpl->iface) {
328
  return "(null)";
 
354
  }
355
 
356
  if (smpl->ctx == nullptr) {
357
+ return llama_sampler_init(
358
  /* .iface = */ smpl->iface,
359
+ /* .ctx = */ nullptr
360
+ );
361
  }
362
 
363
  GGML_ABORT("the sampler does not support cloning");
 
479
  };
480
 
481
  struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
482
+ return llama_sampler_init(
483
  /* .iface = */ &llama_sampler_chain_i,
484
  /* .ctx = */ new llama_sampler_chain {
485
  /* .params = */ params,
486
  /* .samplers = */ {},
487
  /* .t_sample_us = */ 0,
488
  /* .n_sample = */ 0,
489
+ }
490
+ );
491
  }
492
 
493
  void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
 
553
  };
554
 
555
  struct llama_sampler * llama_sampler_init_greedy() {
556
+ return llama_sampler_init(
557
  /* .iface = */ &llama_sampler_greedy_i,
558
+ /* .ctx = */ nullptr
559
+ );
560
  }
561
 
562
  // dist
 
615
 
616
  struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
617
  auto seed_cur = get_rng_seed(seed);
618
+ return llama_sampler_init(
619
  /* .iface = */ &llama_sampler_dist_i,
620
  /* .ctx = */ new llama_sampler_dist {
621
  /* .seed = */ seed,
622
  /* .seed_cur = */ seed_cur,
623
  /* .rng = */ std::mt19937(seed_cur),
624
+ }
625
+ );
626
  }
627
 
628
  // softmax
 
645
  };
646
 
647
  struct llama_sampler * llama_sampler_init_softmax() {
648
+ return llama_sampler_init(
649
  /* .iface = */ &llama_sampler_softmax_i,
650
+ /* .ctx = */ nullptr
651
+ );
652
  }
653
 
654
  // top-k
 
685
  };
686
 
687
  struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
688
+ return llama_sampler_init(
689
  /* .iface = */ &llama_sampler_top_k_i,
690
  /* .ctx = */ new llama_sampler_top_k {
691
  /* .k = */ k,
692
+ }
693
+ );
694
  }
695
 
696
  // top-p
 
751
  };
752
 
753
  struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
754
+ return llama_sampler_init(
755
  /* .iface = */ &llama_sampler_top_p_i,
756
  /* .ctx = */ new llama_sampler_top_p {
757
  /* .p = */ p,
758
  /* .min_keep = */ min_keep,
759
+ }
760
+ );
761
  }
762
 
763
  // min-p
 
847
  };
848
 
849
  struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
850
+ return llama_sampler_init(
851
  /* .iface = */ &llama_sampler_min_p_i,
852
  /* .ctx = */ new llama_sampler_min_p {
853
  /* .p = */ p,
854
  /* .min_keep = */ min_keep,
855
+ }
856
+ );
857
  }
858
 
859
  // typical
 
946
  };
947
 
948
  struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
949
+ return llama_sampler_init(
950
  /* .iface = */ &llama_sampler_typical_i,
951
  /* .ctx = */ new llama_sampler_typical {
952
  /* .p = */ p,
953
  /* .min_keep = */ min_keep,
954
+ }
955
+ );
956
  }
957
 
958
  // temp
 
990
  };
991
 
992
  struct llama_sampler * llama_sampler_init_temp(float temp) {
993
+ return llama_sampler_init(
994
  /* .iface = */ &llama_sampler_temp_i,
995
  /* .ctx = */ new llama_sampler_temp {
996
  /*.temp = */ temp,
997
+ }
998
+ );
999
  }
1000
 
1001
  // temp-ext
 
1100
  };
1101
 
1102
  struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
1103
+ return llama_sampler_init(
1104
  /* .iface = */ &llama_sampler_temp_ext_i,
1105
  /* .ctx = */ new llama_sampler_temp_ext {
1106
  /* .temp = */ temp,
1107
  /* .delta = */ delta,
1108
  /* .exponent = */ exponent,
1109
+ }
1110
+ );
1111
  }
1112
 
1113
  // xtc
 
1192
 
1193
  struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
1194
  auto seed_cur = get_rng_seed(seed);
1195
+ return llama_sampler_init(
1196
  /* .iface = */ &llama_sampler_xtc_i,
1197
  /* .ctx = */ new llama_sampler_xtc {
1198
  /* .probability = */ p,
 
1201
  /* .seed = */ seed,
1202
  /* .seed_cur = */ seed_cur,
1203
  /* .rng = */ std::mt19937(seed_cur),
1204
+ }
1205
+ );
1206
  }
1207
 
1208
  // mirostat
 
1299
 
1300
  struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
1301
  auto seed_cur = get_rng_seed(seed);
1302
+ return llama_sampler_init(
1303
  /* .iface = */ &llama_sampler_mirostat_i,
1304
  /* .ctx = */ new llama_sampler_mirostat {
1305
  /* .n_vocab = */ n_vocab,
 
1310
  /* .m = */ m,
1311
  /* .mu = */ 2.0f*tau,
1312
  /* .rng = */ std::mt19937(seed_cur),
1313
+ }
1314
+ );
1315
  }
1316
 
1317
  // mirostat v2
 
1398
 
1399
  struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
1400
  auto seed_cur = get_rng_seed(seed);
1401
+ return llama_sampler_init(
1402
  /* .iface = */ &llama_sampler_mirostat_v2_i,
1403
  /* .ctx = */ new llama_sampler_mirostat_v2 {
1404
  /* .seed = */ seed,
 
1407
  /* .eta = */ eta,
1408
  /* .mu = */ 2.0f*tau,
1409
  /* .rng = */ std::mt19937(seed_cur),
1410
+ }
1411
+ );
1412
  }
1413
 
1414
  // grammar
 
1449
  const char ** trigger_words,
1450
  size_t num_trigger_words,
1451
  const llama_token * trigger_tokens,
1452
+ size_t num_trigger_tokens,
1453
+ const char ** trigger_patterns,
1454
+ size_t num_trigger_patterns);
1455
 
1456
  static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
1457
  auto * ctx = (llama_sampler_grammar *) smpl->ctx;
 
1459
  return;
1460
  }
1461
 
1462
+ std::vector<const char *> trigger_patterns_c;
1463
+ trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size());
1464
+ for (auto & trigger_pattern : ctx->grammar->trigger_patterns) {
1465
+ trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
1466
  }
1467
+
1468
  auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
1469
+ ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
1470
  ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
1471
 
1472
  llama_grammar_free_impl(ctx->grammar);
 
1476
  static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
1477
  const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
1478
 
1479
+ auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0, nullptr, 0);
1480
+ GGML_ASSERT(result);
1481
 
1482
  // copy the state
1483
  {
 
1521
  const char ** trigger_words,
1522
  size_t num_trigger_words,
1523
  const llama_token * trigger_tokens,
1524
+ size_t num_trigger_tokens,
1525
+ const char ** trigger_patterns,
1526
+ size_t num_trigger_patterns) {
1527
  auto * ctx = new llama_sampler_grammar;
1528
 
1529
  if (grammar_str != nullptr && grammar_str[0] != '\0') {
1530
+ // TODO: remove trigger_words support.
1531
+ if (trigger_words != nullptr && num_trigger_words > 0) {
1532
+ GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
1533
+ std::string trigger_pattern("[\\s\\S]*?(");
1534
+ for (size_t i = 0; i < num_trigger_words; ++i) {
1535
+ static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
1536
+ if (i > 0) {
1537
+ trigger_pattern += "|";
1538
+ }
1539
+ trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
1540
+ }
1541
+ trigger_pattern += ")[\\s\\S]*";
1542
+ auto trigger_pattern_c = trigger_pattern.c_str();
1543
+ trigger_patterns = &trigger_pattern_c;
1544
+ num_trigger_patterns = 1;
1545
+ }
1546
  *ctx = {
1547
  /* .vocab = */ vocab,
1548
  /* .grammar_str = */ grammar_str,
1549
  /* .grammar_root = */ grammar_root,
1550
+ /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
1551
  };
1552
+ if (!ctx->grammar) {
1553
+ delete ctx;
1554
+ return nullptr;
1555
+ }
1556
  } else {
1557
  *ctx = {
1558
  /* .vocab = */ vocab,
 
1562
  };
1563
  }
1564
 
1565
+ return llama_sampler_init(
1566
  /* .iface = */ &llama_sampler_grammar_i,
1567
+ /* .ctx = */ ctx
1568
+ );
1569
  }
1570
 
1571
  struct llama_sampler * llama_sampler_init_grammar(
1572
  const struct llama_vocab * vocab,
1573
  const char * grammar_str,
1574
  const char * grammar_root) {
1575
+ return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0, nullptr, 0);
1576
  }
1577
 
1578
  struct llama_sampler * llama_sampler_init_grammar_lazy(
 
1583
  size_t num_trigger_words,
1584
  const llama_token * trigger_tokens,
1585
  size_t num_trigger_tokens) {
1586
+ return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens, nullptr, 0);
1587
+ }
1588
+
1589
+ struct llama_sampler * llama_sampler_init_grammar_lazy_patterns(
1590
+ const struct llama_vocab * vocab,
1591
+ const char * grammar_str,
1592
+ const char * grammar_root,
1593
+ const char ** trigger_patterns,
1594
+ size_t num_trigger_patterns,
1595
+ const llama_token * trigger_tokens,
1596
+ size_t num_trigger_tokens) {
1597
+ return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, nullptr, 0, trigger_tokens, num_trigger_tokens, trigger_patterns, num_trigger_patterns);
1598
  }
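For callers, the practical change above is that lazy grammars are now triggered by regex patterns instead of literal words: the compatibility branch escapes each trigger word and wraps the alternation in [\s\S]*?( ... )[\s\S]*. A hedged usage sketch (vocab, grammar_str and tool_call_token_id are placeholders):

// with trigger_words = { "<tool_call>" } the shim above would build exactly this pattern
const char * trigger_patterns[] = { "[\\s\\S]*?(<tool_call>)[\\s\\S]*" };
const llama_token trigger_tokens[] = { tool_call_token_id }; // hypothetical token id

struct llama_sampler * smpl = llama_sampler_init_grammar_lazy_patterns(
        vocab, grammar_str, "root",
        trigger_patterns, 1,
        trigger_tokens, 1);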
1599
 
1600
  // penalties
 
1723
  float penalty_present) {
1724
  penalty_last_n = std::max(penalty_last_n, 0);
1725
 
1726
+ return llama_sampler_init(
1727
  /* .iface = */ &llama_sampler_penalties_i,
1728
  /* .ctx = */ new llama_sampler_penalties {
1729
  /* .penalty_last_n = */ penalty_last_n,
 
1732
  /* .penalty_present = */ penalty_present,
1733
  /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
1734
  /* .token_count = */ {},
1735
+ }
1736
+ );
1737
+ }
1738
+
1739
+ // top-n-sigma
1740
+
1741
+ struct llama_sampler_top_n_sigma {
1742
+ const float n;
1743
+ };
1744
+
1745
+ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
1746
+ return "top-n-sigma";
1747
+ }
1748
+
1749
+ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1750
+ const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
1751
+
1752
+ // find max logit and calculate mean
1753
+ float max = cur_p->data[0].logit;
1754
+ float logits_sum = 0;
1755
+ for (size_t i = 0; i < cur_p->size; ++i) {
1756
+ if (cur_p->data[i].logit > max) {
1757
+ max = cur_p->data[i].logit;
1758
+ }
1759
+ logits_sum += cur_p->data[i].logit;
1760
+ }
1761
+ float mean = logits_sum/cur_p->size;
1762
+
1763
+ // calculate standard deviation
1764
+ float acc = 0;
1765
+ for (size_t i = 0; i < cur_p->size; ++i) {
1766
+ acc += pow(cur_p->data[i].logit - mean, 2);
1767
+ }
1768
+ float std = sqrt(acc/cur_p->size);
1769
+
1770
+ //apply mask
1771
+ for (size_t i = 0; i < cur_p->size; ++i) {
1772
+ if (cur_p->data[i].logit < max - (ctx->n * std)) {
1773
+ cur_p->data[i].logit = -INFINITY;
1774
+ }
1775
+ }
1776
+ llama_sampler_softmax_impl(cur_p);
1777
+ }
1778
+
1779
+ static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
1780
+ const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
1781
+ return llama_sampler_init_top_n_sigma(ctx->n);
1782
+ }
1783
+
1784
+ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
1785
+ delete (llama_sampler_top_n_sigma *) smpl->ctx;
1786
+ }
1787
+
1788
+ static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
1789
+ /* .name = */ llama_sampler_top_n_sigma_name,
1790
+ /* .accept = */ nullptr,
1791
+ /* .apply = */ llama_sampler_top_n_sigma_apply,
1792
+ /* .reset = */ nullptr,
1793
+ /* .clone = */ llama_sampler_top_n_sigma_clone,
1794
+ /* .free = */ llama_sampler_top_n_sigma_free,
1795
+ };
1796
+
1797
+ struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
1798
+ return llama_sampler_init(
1799
+ /* .iface = */ &llama_sampler_top_n_sigma_i,
1800
+ /* .ctx = */ new llama_sampler_top_n_sigma {
1801
+ /* .n = */ n,
1802
+ }
1803
+ );
1804
  }
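In plain terms, the new top-n-sigma sampler keeps a token only if its logit is at least max_logit - n * stddev(logits), then renormalizes with the shared softmax helper. A hedged sketch of how it would slot into a chain next to the other samplers defined in this file (the seed is arbitrary):

struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_init_top_n_sigma(1.0f)); // keep logits within one std dev of the max
llama_sampler_chain_add(chain, llama_sampler_init_dist(1234));        // then sample from the surviving tokens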
1805
 
1806
  // DRY
 
2153
  }
2154
  }
2155
 
2156
+ return llama_sampler_init(
2157
  /* .iface = */ &llama_sampler_dry_i,
2158
  /* .ctx = */ new llama_sampler_dry {
2159
  /* .total_context_size = */ context_size,
 
2165
  /* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
2166
  /* .dry_max_token_repeat = */ {},
2167
  /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
2168
+ }
2169
+ );
2170
  }
2171
 
2172
  // wrapper for test-sampling.cpp
 
2267
  int32_t n_vocab,
2268
  int32_t n_logit_bias,
2269
  const llama_logit_bias * logit_bias) {
2270
+ return llama_sampler_init(
2271
  /* .iface = */ &llama_sampler_logit_bias_i,
2272
  /* .ctx = */ new llama_sampler_logit_bias {
2273
  /* .n_vocab = */ n_vocab,
2274
  /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
2275
  /* .to_search = */ {},
2276
+ }
2277
+ );
2278
  }
2279
 
2280
  // infill
 
2489
  };
2490
 
2491
  struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
2492
+ return llama_sampler_init(
2493
  /* .iface = */ &llama_sampler_infill_i,
2494
  /* .ctx = */ new llama_sampler_infill {
2495
  /* .vocab = */ vocab,
2496
  /* .buf0 = */ std::vector<char>(512),
2497
  /* .buf1 = */ std::vector<char>(512),
2498
+ }
2499
+ );
2500
  }
2501
 
2502
  // utils
examples/talk-llama/llama-vocab.cpp CHANGED
@@ -16,6 +16,7 @@
16
  #include <queue>
17
  #include <set>
18
  #include <unordered_map>
 
19
 
20
  //
21
  // helpers
@@ -341,6 +342,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
341
  case LLAMA_VOCAB_PRE_TYPE_MPT:
342
  case LLAMA_VOCAB_PRE_TYPE_OLMO:
343
  case LLAMA_VOCAB_PRE_TYPE_JAIS:
 
344
  regex_exprs = {
345
  "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
346
  };
@@ -392,6 +394,27 @@ struct llm_tokenizer_bpe : llm_tokenizer {
392
  "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
393
  };
394
  break;
395
  default:
396
  // default regex for BPE tokenization pre-processing
397
  regex_exprs = {
@@ -1483,7 +1506,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1483
  tokenizer_pre == "llama3" ||
1484
  tokenizer_pre == "llama-v3" ||
1485
  tokenizer_pre == "llama-bpe"||
1486
- tokenizer_pre == "falcon3") {
 
1487
  pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
1488
  ignore_merges = true;
1489
  add_bos = true;
@@ -1549,6 +1573,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1549
  pre_type = LLAMA_VOCAB_PRE_TYPE_PORO;
1550
  clean_spaces = false;
1551
  } else if (
 
1552
  tokenizer_pre == "chatglm-bpe") {
1553
  pre_type = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
1554
  special_bos_id = LLAMA_TOKEN_NULL;
@@ -1592,6 +1617,23 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1592
  } else if (
1593
  tokenizer_pre == "megrez") {
1594
  pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
1595
  } else {
1596
  throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
1597
  }
@@ -1769,6 +1811,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1769
  || t.first == "<end_of_turn>"
1770
  || t.first == "<|endoftext|>"
1771
  || t.first == "<EOT>"
 
1772
  || t.first == "<|end▁of▁sentence|>" // DeepSeek
1773
  ) {
1774
  special_eot_id = t.second;
@@ -1799,8 +1842,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1799
  if (false
1800
  || t.first == "<|fim_prefix|>" // Qwen
1801
  || t.first == "<fim-prefix>"
 
1802
  || t.first == "<|fim▁begin|>" // DeepSeek
1803
  || t.first == "<PRE>"
 
1804
  ) {
1805
  special_fim_pre_id = t.second;
1806
  if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -1816,8 +1861,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1816
  if (false
1817
  || t.first == "<|fim_suffix|>" // Qwen
1818
  || t.first == "<fim-suffix>"
 
1819
  || t.first == "<|fim▁hole|>" // DeepSeek
1820
  || t.first == "<SUF>"
 
1821
  ) {
1822
  special_fim_suf_id = t.second;
1823
  if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -1833,8 +1880,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1833
  if (false
1834
  || t.first == "<|fim_middle|>" // Qwen
1835
  || t.first == "<fim-middle>"
 
1836
  || t.first == "<|fim▁end|>" // DeepSeek
1837
  || t.first == "<MID>"
 
1838
  ) {
1839
  special_fim_mid_id = t.second;
1840
  if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -1850,6 +1899,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1850
  if (false
1851
  || t.first == "<|fim_pad|>" // Qwen
1852
  || t.first == "<fim-pad>"
 
1853
  || t.first == "<PAD>"
1854
  ) {
1855
  special_fim_pad_id = t.second;
@@ -1868,6 +1918,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1868
  || t.first == "<|repo_name|>"
1869
  || t.first == "<fim-repo>"
1870
  || t.first == "<REPO>"
 
1871
  ) {
1872
  special_fim_rep_id = t.second;
1873
  if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -1919,6 +1970,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1919
  || t.first == "<|endoftext|>"
1920
  || t.first == "<|eom_id|>"
1921
  || t.first == "<EOT>"
 
1922
  ) {
1923
  special_eog_ids.insert(t.second);
1924
  if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2177,14 +2229,12 @@ void llama_vocab::impl::tokenizer_st_partition(std::forward_list<fragment_buffer
2177
  // find the first occurrence of a given special token in this fragment
2178
  // passing offset argument only limit the "search area" but match coordinates
2179
  // are still relative to the source full raw_text
2180
- auto match = raw_text.find(text, raw_text_base_offset);
 
2181
 
2182
  // no occurrences found, stop processing this fragment for a given special token
2183
  if (match == std::string::npos) break;
2184
 
2185
- // check if match is within bounds of offset <-> length
2186
- if (match + text.length() > raw_text_base_offset + raw_text_base_length) break;
2187
-
2188
  #ifdef PRETOKENIZERDEBUG
2189
  LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
2190
  #endif
 
16
  #include <queue>
17
  #include <set>
18
  #include <unordered_map>
19
+ #include <cctype>
20
 
21
  //
22
  // helpers
 
342
  case LLAMA_VOCAB_PRE_TYPE_MPT:
343
  case LLAMA_VOCAB_PRE_TYPE_OLMO:
344
  case LLAMA_VOCAB_PRE_TYPE_JAIS:
345
+ case LLAMA_VOCAB_PRE_TYPE_TRILLION:
346
  regex_exprs = {
347
  "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
348
  };
 
394
  "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
395
  };
396
  break;
397
+ case LLAMA_VOCAB_PRE_TYPE_GPT4O:
398
+ regex_exprs = {
399
+ // original regex from tokenizer.json
400
+ // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
401
+ "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
402
+ };
403
+ break;
404
+ case LLAMA_VOCAB_PRE_TYPE_SUPERBPE:
405
+ regex_exprs = {
406
+ "\\p{N}+",
407
+ "(?=(\\d{3})+(?!\\d))",
408
+ };
409
+ break;
410
+ case LLAMA_VOCAB_PRE_TYPE_BAILINGMOE:
411
+ regex_exprs = {
412
+ // original regex from tokenizer.json
413
+ // "'(?i:[sdmt]|ll|ve|re)|[^\\r\\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]++[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+"
414
+ // FIXME? Changed possessive quantifiers (?+ and ++) to greedy to avoid errors and imatrix hanging (tried atomic grouping but it's not supported?)
415
+ "'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
416
+ };
417
+ break;
418
  default:
419
  // default regex for BPE tokenization pre-processing
420
  regex_exprs = {
 
1506
  tokenizer_pre == "llama3" ||
1507
  tokenizer_pre == "llama-v3" ||
1508
  tokenizer_pre == "llama-bpe"||
1509
+ tokenizer_pre == "falcon3" ||
1510
+ tokenizer_pre == "pixtral") {
1511
  pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
1512
  ignore_merges = true;
1513
  add_bos = true;
 
1573
  pre_type = LLAMA_VOCAB_PRE_TYPE_PORO;
1574
  clean_spaces = false;
1575
  } else if (
1576
+ tokenizer_pre == "glm4" ||
1577
  tokenizer_pre == "chatglm-bpe") {
1578
  pre_type = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
1579
  special_bos_id = LLAMA_TOKEN_NULL;
 
1617
  } else if (
1618
  tokenizer_pre == "megrez") {
1619
  pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
1620
+ } else if (
1621
+ tokenizer_pre == "gpt-4o" ||
1622
+ tokenizer_pre == "llama4") {
1623
+ pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
1624
+ clean_spaces = false;
1625
+ } else if (
1626
+ tokenizer_pre == "superbpe") {
1627
+ pre_type = LLAMA_VOCAB_PRE_TYPE_SUPERBPE;
1628
+ clean_spaces = false;
1629
+ } else if (
1630
+ tokenizer_pre == "trillion") {
1631
+ pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION;
1632
+ clean_spaces = false;
1633
+ } else if (
1634
+ tokenizer_pre == "bailingmoe") {
1635
+ pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
1636
+ clean_spaces = false;
1637
  } else {
1638
  throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
1639
  }
 
1811
  || t.first == "<end_of_turn>"
1812
  || t.first == "<|endoftext|>"
1813
  || t.first == "<EOT>"
1814
+ || t.first == "_<EOT>"
1815
  || t.first == "<|end▁of▁sentence|>" // DeepSeek
1816
  ) {
1817
  special_eot_id = t.second;
 
1842
  if (false
1843
  || t.first == "<|fim_prefix|>" // Qwen
1844
  || t.first == "<fim-prefix>"
1845
+ || t.first == "<fim_prefix>" // Granite
1846
  || t.first == "<|fim▁begin|>" // DeepSeek
1847
  || t.first == "<PRE>"
1848
+ || t.first == "▁<PRE>" // CodeLlama
1849
  ) {
1850
  special_fim_pre_id = t.second;
1851
  if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
 
1861
  if (false
1862
  || t.first == "<|fim_suffix|>" // Qwen
1863
  || t.first == "<fim-suffix>"
1864
+ || t.first == "<fim_suffix>" // Granite
1865
  || t.first == "<|fim▁hole|>" // DeepSeek
1866
  || t.first == "<SUF>"
1867
+ || t.first == "▁<SUF>" // CodeLlama
1868
  ) {
1869
  special_fim_suf_id = t.second;
1870
  if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
 
1880
  if (false
1881
  || t.first == "<|fim_middle|>" // Qwen
1882
  || t.first == "<fim-middle>"
1883
+ || t.first == "<fim_middle>" // Granite
1884
  || t.first == "<|fim▁end|>" // DeepSeek
1885
  || t.first == "<MID>"
1886
+ || t.first == "▁<MID>" // CodeLlama
1887
  ) {
1888
  special_fim_mid_id = t.second;
1889
  if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
 
1899
  if (false
1900
  || t.first == "<|fim_pad|>" // Qwen
1901
  || t.first == "<fim-pad>"
1902
+ || t.first == "<fim_pad>" // Granite
1903
  || t.first == "<PAD>"
1904
  ) {
1905
  special_fim_pad_id = t.second;
 
1918
  || t.first == "<|repo_name|>"
1919
  || t.first == "<fim-repo>"
1920
  || t.first == "<REPO>"
1921
+ || t.first == "<reponame>" // Granite
1922
  ) {
1923
  special_fim_rep_id = t.second;
1924
  if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
 
1970
  || t.first == "<|endoftext|>"
1971
  || t.first == "<|eom_id|>"
1972
  || t.first == "<EOT>"
1973
+ || t.first == "_<EOT>"
1974
  ) {
1975
  special_eog_ids.insert(t.second);
1976
  if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
 
2229
  // find the first occurrence of a given special token in this fragment
2230
  // passing offset argument only limit the "search area" but match coordinates
2231
  // are still relative to the source full raw_text
2232
+ // string_view begins at pos 0 for the same reason
2233
+ auto match = std::string_view(raw_text.data(), raw_text_base_offset + raw_text_base_length).find(text, raw_text_base_offset);
2234
 
2235
  // no occurrences found, stop processing this fragment for a given special token
2236
  if (match == std::string::npos) break;
2237
 
 
 
 
2238
  #ifdef PRETOKENIZERDEBUG
2239
  LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
2240
  #endif
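Aside from the new pre-tokenizer types, the behavioral change in this file is in tokenizer_st_partition: the special-token lookup now searches through a std::string_view that ends exactly at the fragment boundary, which is why the old explicit out-of-bounds check on the match could be dropped. A minimal sketch of the idea with a hypothetical helper name:

#include <string>
#include <string_view>

static size_t find_in_fragment(const std::string & raw_text, const std::string & text,
                               size_t base_offset, size_t base_length) {
    // the view is truncated at the fragment end, so a match that would spill past it is simply not found;
    // the returned offset is still relative to the full raw_text, matching the previous behavior
    std::string_view bounded(raw_text.data(), base_offset + base_length);
    return bounded.find(text, base_offset); // std::string_view::npos == std::string::npos
}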
examples/talk-llama/llama.cpp CHANGED
The diff for this file is too large to render. See raw diff
 
examples/talk-llama/llama.h CHANGED
@@ -60,6 +60,7 @@ extern "C" {
60
  struct llama_model;
61
  struct llama_context;
62
  struct llama_sampler;
 
63
 
64
  typedef int32_t llama_pos;
65
  typedef int32_t llama_token;
@@ -105,6 +106,12 @@ extern "C" {
105
  LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
106
  LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
107
  LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
108
  };
109
 
110
  enum llama_rope_type {
@@ -213,7 +220,7 @@ extern "C" {
213
  LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported
214
  };
215
 
216
- // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
217
  typedef struct llama_token_data {
218
  llama_token id; // token id
219
  float logit; // log-odds of the token
@@ -275,10 +282,18 @@ extern "C" {
275
  };
276
  };
277

278
  struct llama_model_params {
279
  // NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
280
  ggml_backend_dev_t * devices;
281

282
  int32_t n_gpu_layers; // number of layers to store in VRAM
283
  enum llama_split_mode split_mode; // how to split the model across multiple GPUs
284
 
@@ -307,7 +322,7 @@ extern "C" {
307
  };
308
 
309
  // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
310
- // https://github.com/ggerganov/llama.cpp/pull/7544
311
  struct llama_context_params {
312
  uint32_t n_ctx; // text context, 0 = from model
313
  uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
@@ -320,7 +335,7 @@ extern "C" {
320
  enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
321
  enum llama_attention_type attention_type; // attention type to use for embeddings
322
 
323
- // ref: https://github.com/ggerganov/llama.cpp/pull/2054
324
  float rope_freq_base; // RoPE base frequency, 0 = from model
325
  float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
326
  float yarn_ext_factor; // YaRN extrapolation mix factor, negative = from model
@@ -353,17 +368,18 @@ extern "C" {
353
 
354
  // model quantization parameters
355
  typedef struct llama_model_quantize_params {
356
- int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
357
- enum llama_ftype ftype; // quantize to this llama_ftype
358
- enum ggml_type output_tensor_type; // output tensor type
359
- enum ggml_type token_embedding_type; // token embeddings tensor type
360
- bool allow_requantize; // allow quantizing non-f32/f16 tensors
361
- bool quantize_output_tensor; // quantize output.weight
362
- bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
363
- bool pure; // quantize all tensors to the default type
364
- bool keep_split; // quantize to the same number of shards
365
- void * imatrix; // pointer to importance matrix data
366
- void * kv_overrides; // pointer to vector containing overrides
 
367
  } llama_model_quantize_params;
368
 
369
  typedef struct llama_logit_bias {
@@ -385,7 +401,7 @@ extern "C" {
385
  struct llama_adapter_lora;
386
 
387
  // Helpers for getting default parameters
388
- // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
389
  LLAMA_API struct llama_model_params llama_model_default_params(void);
390
  LLAMA_API struct llama_context_params llama_context_default_params(void);
391
  LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void);
@@ -468,7 +484,8 @@ extern "C" {
468
  DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
469
 
470
  LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
471
- LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
 
472
 
473
  LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
474
  LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
@@ -477,6 +494,7 @@ extern "C" {
477
  LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
478
  LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
479
  LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
 
480
 
481
  // Get the model's RoPE frequency scaling factor
482
  LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
@@ -584,7 +602,7 @@ extern "C" {
584
  // KV cache
585
  //
586
 
587
- // TODO: remove llama_kv_cache_view_* API
588
 
589
  // Information associated with an individual cell in the KV cache view.
590
  struct llama_kv_cache_view_cell {
@@ -639,13 +657,19 @@ extern "C" {
639
 
640
  // Returns the number of tokens in the KV cache (slow, use only for debug)
641
  // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
642
- LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
643
 
644
  // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
645
- LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
646
 
647
  // Clear the KV cache - both cell info is erased and KV data is zeroed
648
- LLAMA_API void llama_kv_cache_clear(
649
  struct llama_context * ctx);
650
 
651
  // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
@@ -653,7 +677,7 @@ extern "C" {
653
  // seq_id < 0 : match any sequence
654
  // p0 < 0 : [0, p1]
655
  // p1 < 0 : [p0, inf)
656
- LLAMA_API bool llama_kv_cache_seq_rm(
657
  struct llama_context * ctx,
658
  llama_seq_id seq_id,
659
  llama_pos p0,
@@ -663,7 +687,7 @@ extern "C" {
663
  // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
664
  // p0 < 0 : [0, p1]
665
  // p1 < 0 : [p0, inf)
666
- LLAMA_API void llama_kv_cache_seq_cp(
667
  struct llama_context * ctx,
668
  llama_seq_id seq_id_src,
669
  llama_seq_id seq_id_dst,
@@ -671,17 +695,17 @@ extern "C" {
671
  llama_pos p1);
672
 
673
  // Removes all tokens that do not belong to the specified sequence
674
- LLAMA_API void llama_kv_cache_seq_keep(
675
  struct llama_context * ctx,
676
  llama_seq_id seq_id);
677
 
678
  // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
679
  // If the KV cache is RoPEd, the KV data is updated accordingly:
680
  // - lazily on next llama_decode()
681
- // - explicitly with llama_kv_cache_update()
682
  // p0 < 0 : [0, p1]
683
  // p1 < 0 : [p0, inf)
684
- LLAMA_API void llama_kv_cache_seq_add(
685
  struct llama_context * ctx,
686
  llama_seq_id seq_id,
687
  llama_pos p0,
@@ -691,10 +715,10 @@ extern "C" {
691
  // Integer division of the positions by factor of `d > 1`
692
  // If the KV cache is RoPEd, the KV data is updated accordingly:
693
  // - lazily on next llama_decode()
694
- // - explicitly with llama_kv_cache_update()
695
  // p0 < 0 : [0, p1]
696
  // p1 < 0 : [p0, inf)
697
- LLAMA_API void llama_kv_cache_seq_div(
698
  struct llama_context * ctx,
699
  llama_seq_id seq_id,
700
  llama_pos p0,
@@ -702,24 +726,76 @@ extern "C" {
702
  int d);
703
 
704
  // Returns the largest position present in the KV cache for the specified sequence
705
- LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
706
  struct llama_context * ctx,
707
- llama_seq_id seq_id);
708
-
709
- // TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache
710
- // how to avoid this?
711
 
712
  // Defragment the KV cache
713
  // This will be applied:
714
  // - lazily on next llama_decode()
715
- // - explicitly with llama_kv_cache_update()
716
- LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx);
717
 
718
  // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
719
- LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
720
 
721
- // Check if the context supports KV cache shifting
722
- LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx);
723
 
724
  //
725
  // State / sessions
@@ -883,6 +959,10 @@ extern "C" {
883
  // If set to true, the model will only attend to the past tokens
884
  LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
885

886
  // Set abort callback
887
  LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
888
 
@@ -1040,7 +1120,7 @@ extern "C" {
1040
 
1041
  /// Apply chat template. Inspired by hf apply_chat_template() on python.
1042
  /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
1043
- /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
1044
  /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead.
1045
  /// @param chat Pointer to a list of multiple llama_chat_message
1046
  /// @param n_msg Number of llama_chat_message in this chat
@@ -1114,11 +1194,12 @@ extern "C" {
1114
  };
1115
 
1116
  struct llama_sampler {
1117
- struct llama_sampler_i * iface;
1118
- llama_sampler_context_t ctx;
1119
  };
1120
 
1121
  // mirror of llama_sampler_i:
 
1122
  LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
1123
  LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
1124
  LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
@@ -1148,7 +1229,7 @@ extern "C" {
1148
  /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
1149
  /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
1150
  DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
1151
- "will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)");
1152
 
1153
  /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
1154
  LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
@@ -1156,7 +1237,7 @@ extern "C" {
1156
  /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
1157
  LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep);
1158
 
1159
- /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
1160
  LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep);
1161
 
1162
  /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
@@ -1171,6 +1252,9 @@ extern "C" {
1171
  /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
1172
  LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
1173
 
 
 
 
1174
  /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
1175
  /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
1176
  /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@@ -1194,22 +1278,38 @@ extern "C" {
1194
  float tau,
1195
  float eta);
1196
 
 
 
 
 
1197
  LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
1198
  const struct llama_vocab * vocab,
1199
  const char * grammar_str,
1200
  const char * grammar_root);
1201
 
1202
- /// @details Lazy grammar sampler, introduced in https://github.com/ggerganov/llama.cpp/pull/9639
1203
- /// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in a near future.
1204
- /// @param trigger_tokens A list of tokens that will trigger the grammar sampler.
1205
- LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy(
1206
  const struct llama_vocab * vocab,
1207
  const char * grammar_str,
1208
  const char * grammar_root,
1209
  const char ** trigger_words,
1210
  size_t num_trigger_words,
1211
  const llama_token * trigger_tokens,
1212
- size_t num_trigger_tokens);
1213
 
1214
  /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
1215
  LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
 
60
  struct llama_model;
61
  struct llama_context;
62
  struct llama_sampler;
63
+ struct llama_kv_cache;
64
 
65
  typedef int32_t llama_pos;
66
  typedef int32_t llama_token;
 
106
  LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
107
  LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
108
  LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
109
+ LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
110
+ LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
111
+ LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
112
+ LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
113
+ LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
114
+ LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
115
  };
116
 
117
  enum llama_rope_type {
 
220
  LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported
221
  };
222
 
223
+ // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979)
224
  typedef struct llama_token_data {
225
  llama_token id; // token id
226
  float logit; // log-odds of the token
 
282
  };
283
  };
284
 
285
+ struct llama_model_tensor_buft_override {
286
+ const char * pattern;
287
+ ggml_backend_buffer_type_t buft;
288
+ };
289
+
290
  struct llama_model_params {
291
  // NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
292
  ggml_backend_dev_t * devices;
293
 
294
+ // NULL-terminated list of buffer types to use for tensors that match a pattern
295
+ const struct llama_model_tensor_buft_override * tensor_buft_overrides;
296
+
297
  int32_t n_gpu_layers; // number of layers to store in VRAM
298
  enum llama_split_mode split_mode; // how to split the model across multiple GPUs
299
 
 
322
  };
323
 
324
  // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
325
+ // https://github.com/ggml-org/llama.cpp/pull/7544
326
  struct llama_context_params {
327
  uint32_t n_ctx; // text context, 0 = from model
328
  uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
 
335
  enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
336
  enum llama_attention_type attention_type; // attention type to use for embeddings
337
 
338
+ // ref: https://github.com/ggml-org/llama.cpp/pull/2054
339
  float rope_freq_base; // RoPE base frequency, 0 = from model
340
  float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
341
  float yarn_ext_factor; // YaRN extrapolation mix factor, negative = from model
 
368
 
369
  // model quantization parameters
370
  typedef struct llama_model_quantize_params {
371
+ int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
372
+ enum llama_ftype ftype; // quantize to this llama_ftype
373
+ enum ggml_type output_tensor_type; // output tensor type
374
+ enum ggml_type token_embedding_type; // token embeddings tensor type
375
+ bool allow_requantize; // allow quantizing non-f32/f16 tensors
376
+ bool quantize_output_tensor; // quantize output.weight
377
+ bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
378
+ bool pure; // quantize all tensors to the default type
379
+ bool keep_split; // quantize to the same number of shards
380
+ void * imatrix; // pointer to importance matrix data
381
+ void * kv_overrides; // pointer to vector containing overrides
382
+ void * tensor_types; // pointer to vector containing tensor types
383
  } llama_model_quantize_params;
384
 
385
  typedef struct llama_logit_bias {
 
401
  struct llama_adapter_lora;
402
 
403
  // Helpers for getting default parameters
404
+ // TODO: update API to start accepting pointers to params structs (https://github.com/ggml-org/llama.cpp/discussions/9172)
405
  LLAMA_API struct llama_model_params llama_model_default_params(void);
406
  LLAMA_API struct llama_context_params llama_context_default_params(void);
407
  LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void);
 
484
  DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
485
 
486
  LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
487
+ LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx);
488
+ LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
489
 
490
  LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
491
  LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
 
494
  LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
495
  LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
496
  LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
497
+ LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
498
 
499
  // Get the model's RoPE frequency scaling factor
500
  LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
 
602
  // KV cache
603
  //
604
 
605
+ // TODO: start using struct llama_kv_cache
606
 
607
  // Information associated with an individual cell in the KV cache view.
608
  struct llama_kv_cache_view_cell {
 
657
 
658
  // Returns the number of tokens in the KV cache (slow, use only for debug)
659
  // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
660
+ LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
661
+
662
+ DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
663
+ "use llama_kv_self_n_tokens instead");
664
 
665
  // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
666
+ LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
667
+
668
+ DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
669
+ "use llama_kv_self_used_cells instead");
670
 
671
  // Clear the KV cache - both cell info is erased and KV data is zeroed
672
+ LLAMA_API void llama_kv_self_clear(
673
  struct llama_context * ctx);
674
 
675
  // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
 
677
  // seq_id < 0 : match any sequence
678
  // p0 < 0 : [0, p1]
679
  // p1 < 0 : [p0, inf)
680
+ LLAMA_API bool llama_kv_self_seq_rm(
681
  struct llama_context * ctx,
682
  llama_seq_id seq_id,
683
  llama_pos p0,
 
687
  // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
688
  // p0 < 0 : [0, p1]
689
  // p1 < 0 : [p0, inf)
690
+ LLAMA_API void llama_kv_self_seq_cp(
691
  struct llama_context * ctx,
692
  llama_seq_id seq_id_src,
693
  llama_seq_id seq_id_dst,
 
695
  llama_pos p1);
696
 
697
  // Removes all tokens that do not belong to the specified sequence
698
+ LLAMA_API void llama_kv_self_seq_keep(
699
  struct llama_context * ctx,
700
  llama_seq_id seq_id);
701
 
702
  // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
703
  // If the KV cache is RoPEd, the KV data is updated accordingly:
704
  // - lazily on next llama_decode()
705
+ // - explicitly with llama_kv_self_update()
706
  // p0 < 0 : [0, p1]
707
  // p1 < 0 : [p0, inf)
708
+ LLAMA_API void llama_kv_self_seq_add(
709
  struct llama_context * ctx,
710
  llama_seq_id seq_id,
711
  llama_pos p0,
 
715
  // Integer division of the positions by factor of `d > 1`
716
  // If the KV cache is RoPEd, the KV data is updated accordingly:
717
  // - lazily on next llama_decode()
718
+ // - explicitly with llama_kv_self_update()
719
  // p0 < 0 : [0, p1]
720
  // p1 < 0 : [p0, inf)
721
+ LLAMA_API void llama_kv_self_seq_div(
722
  struct llama_context * ctx,
723
  llama_seq_id seq_id,
724
  llama_pos p0,
 
726
  int d);
727
 
728
  // Returns the largest position present in the KV cache for the specified sequence
729
+ LLAMA_API llama_pos llama_kv_self_seq_pos_max(
730
  struct llama_context * ctx,
731
+ llama_seq_id seq_id);
732
 
733
  // Defragment the KV cache
734
  // This will be applied:
735
  // - lazily on next llama_decode()
736
+ // - explicitly with llama_kv_self_update()
737
+ LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
738
+
739
+ // Check if the context supports KV cache shifting
740
+ LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
741
 
742
  // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
743
+ LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
744
+
745
+ DEPRECATED(LLAMA_API void llama_kv_cache_clear(
746
+ struct llama_context * ctx),
747
+ "use llama_kv_self_clear instead");
748
+
749
+ DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(
750
+ struct llama_context * ctx,
751
+ llama_seq_id seq_id,
752
+ llama_pos p0,
753
+ llama_pos p1),
754
+ "use llama_kv_self_seq_rm instead");
755
+
756
+ DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(
757
+ struct llama_context * ctx,
758
+ llama_seq_id seq_id_src,
759
+ llama_seq_id seq_id_dst,
760
+ llama_pos p0,
761
+ llama_pos p1),
762
+ "use llama_kv_self_seq_cp instead");
763
+
764
+ DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(
765
+ struct llama_context * ctx,
766
+ llama_seq_id seq_id),
767
+ "use llama_kv_self_seq_keep instead");
768
+
769
+ DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(
770
+ struct llama_context * ctx,
771
+ llama_seq_id seq_id,
772
+ llama_pos p0,
773
+ llama_pos p1,
774
+ llama_pos delta),
775
+ "use llama_kv_self_seq_add instead");
776
+
777
+ DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(
778
+ struct llama_context * ctx,
779
+ llama_seq_id seq_id,
780
+ llama_pos p0,
781
+ llama_pos p1,
782
+ int d),
783
+ "use llama_kv_self_seq_div instead");
784
+
785
+ DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
786
+ struct llama_context * ctx,
787
+ llama_seq_id seq_id),
788
+ "use llama_kv_self_seq_pos_max instead");
789
+
790
+ DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx),
791
+ "use llama_kv_self_defrag instead");
792
+
793
+ DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx),
794
+ "use llama_kv_self_can_shift instead");
795
+
796
+ DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx),
797
+ "use llama_kv_self_update instead");
798
 
 
 
799
 
800
  //
801
  // State / sessions
 
959
  // If set to true, the model will only attend to the past tokens
960
  LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
961
 
962
+ // Set whether the model is in warmup mode or not
963
+ // If true, all model tensors are activated during llama_decode() to load and cache their weights.
964
+ LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup);
965
+
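A rough sketch of a warmup pass driven by this flag, assuming `model` and `ctx` are already created; the single-BOS recipe is an assumption, similar in spirit to the warmup run done by llama.cpp's common helpers:

```cpp
#include "llama.h"

#include <vector>

// Sketch: run one throwaway decode in warmup mode so that all weights are
// touched (and mmap-ed pages are faulted in) before serving real requests.
static void warmup(llama_model * model, llama_context * ctx) {
    const llama_vocab * vocab = llama_model_get_vocab(model);

    std::vector<llama_token> tokens = { llama_vocab_bos(vocab) };

    llama_set_warmup(ctx, true);
    llama_decode(ctx, llama_batch_get_one(tokens.data(), (int32_t) tokens.size()));
    llama_set_warmup(ctx, false);

    // discard whatever the warmup pass put into the KV cache
    llama_kv_self_clear(ctx);
}
```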
966
  // Set abort callback
967
  LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
968
 
 
1120
 
1121
  /// Apply chat template. Inspired by hf apply_chat_template() on python.
1122
  /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
1123
+ /// NOTE: This function does not use a jinja parser. It only supports a pre-defined list of templates. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
1124
  /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead.
1125
  /// @param chat Pointer to a list of multiple llama_chat_message
1126
  /// @param n_msg Number of llama_chat_message in this chat
 
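For orientation, a hedged sketch of calling the template formatter; the parameter list used here (no model argument, fixed-size output buffer) is an assumption that should be checked against the full declaration, and the "chatml" template name and the messages are illustrative:

```cpp
#include "llama.h"

#include <string>
#include <vector>

// Sketch: format a two-message chat with a built-in template.
static std::string format_chat() {
    std::vector<llama_chat_message> msgs = {
        { "system", "You are a helpful assistant." },
        { "user",   "Hello!"                       },
    };

    std::vector<char> buf(4096);
    const int32_t n = llama_chat_apply_template(
        "chatml",                // one of the pre-defined template names
        msgs.data(), msgs.size(),
        /*add_ass=*/ true,       // append the assistant prefix for generation
        buf.data(), (int32_t) buf.size());

    // a real caller would resize buf and retry when n exceeds its size
    if (n < 0 || n > (int32_t) buf.size()) {
        return std::string();
    }
    return std::string(buf.data(), n);
}
```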
1194
  };
1195
 
1196
  struct llama_sampler {
1197
+ const struct llama_sampler_i * iface;
1198
+ llama_sampler_context_t ctx;
1199
  };
1200
 
1201
  // mirror of llama_sampler_i:
1202
+ LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
1203
  LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
1204
  LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
1205
  LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
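Since llama_sampler_init() is the intended way to wrap a custom llama_sampler_i implementation, here is a hedged sketch of a user-defined sampler that bans one token id; the llama_sampler_i field order and the assumption that unused callbacks may be left null should be verified against the full interface definition:

```cpp
#include "llama.h"

#include <cmath>

// Illustrative context object for the custom sampler.
struct ban_token_ctx {
    llama_token banned;
};

static const char * ban_token_name(const llama_sampler * /*smpl*/) {
    return "ban-token";
}

static void ban_token_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
    const auto * bctx = (const ban_token_ctx *) smpl->ctx;
    for (size_t i = 0; i < cur_p->size; ++i) {
        if (cur_p->data[i].id == bctx->banned) {
            cur_p->data[i].logit = -INFINITY; // never pick the banned token
        }
    }
}

static void ban_token_free(llama_sampler * smpl) {
    delete (ban_token_ctx *) smpl->ctx;
}

static const llama_sampler_i ban_token_iface = {
    /*.name   =*/ ban_token_name,
    /*.accept =*/ nullptr,
    /*.apply  =*/ ban_token_apply,
    /*.reset  =*/ nullptr,
    /*.clone  =*/ nullptr,
    /*.free   =*/ ban_token_free,
};

static llama_sampler * ban_token_init(llama_token banned) {
    return llama_sampler_init(&ban_token_iface, new ban_token_ctx{ banned });
}
```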
 
1229
  /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
1230
  /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
1231
  DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
1232
+ "will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)");
1233
 
1234
  /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
1235
  LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
 
1237
  /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
1238
  LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep);
1239
 
1240
+ /// @details Minimum P sampling as described in https://github.com/ggml-org/llama.cpp/pull/3841
1241
  LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep);
1242
 
1243
  /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
 
1252
  /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
1253
  LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
1254
 
1255
+ /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641
1256
+ LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n);
1257
+
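The individual samplers above are normally composed through the sampler-chain helpers declared elsewhere in this header; a short sketch of a typical chain, with parameter values chosen only for illustration:

```cpp
#include "llama.h"

// Sketch: top-k and min-p filtering, temperature, then a probabilistic pick.
static llama_sampler * make_sampler_chain(uint32_t seed) {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_min_p(0.05f, 1));
    llama_sampler_chain_add(chain, llama_sampler_init_temp (0.80f));
    llama_sampler_chain_add(chain, llama_sampler_init_dist (seed));

    return chain;
}

// usage: llama_token tok = llama_sampler_sample(chain, ctx, -1);
```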
1258
  /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
1259
  /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
1260
  /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
 
1278
  float tau,
1279
  float eta);
1280
 
1281
+ /// @details Initializes a GBNF grammar, see grammars/README.md for details.
1282
+ /// @param vocab The vocabulary that this grammar will be used with.
1283
+ /// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails.
1284
+ /// @param grammar_root The name of the start symbol for the grammar.
1285
  LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
1286
  const struct llama_vocab * vocab,
1287
  const char * grammar_str,
1288
  const char * grammar_root);
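A small sketch of constraining generation with this sampler; the grammar string is illustrative:

```cpp
#include "llama.h"

// Sketch: restrict output to a literal "yes" or "no".
static llama_sampler * make_yes_no_sampler(const llama_model * model) {
    const llama_vocab * vocab = llama_model_get_vocab(model);

    // returns NULL if the grammar fails to parse
    return llama_sampler_init_grammar(vocab, "root ::= \"yes\" | \"no\"", "root");
}
```

In practice the grammar sampler is added to a chain ahead of the final token pick (e.g. llama_sampler_init_dist).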
1289
 
1290
+ DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy(
 
 
 
1291
  const struct llama_vocab * vocab,
1292
  const char * grammar_str,
1293
  const char * grammar_root,
1294
  const char ** trigger_words,
1295
  size_t num_trigger_words,
1296
  const llama_token * trigger_tokens,
1297
+ size_t num_trigger_tokens),
1298
+ "use llama_sampler_init_grammar_lazy_patterns instead");
1299
+
1300
+
1301
+ /// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639
1302
+ /// @param trigger_patterns A list of patterns that will trigger the grammar sampler. Each pattern is matched from the start of the generation output, and the grammar sampler is fed content starting from the pattern's first match group.
1303
+ /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. The grammar sampler will be fed content starting from the trigger token (inclusive).
1304
+ LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy_patterns(
1305
+ const struct llama_vocab * vocab,
1306
+ const char * grammar_str,
1307
+ const char * grammar_root,
1308
+ const char ** trigger_patterns,
1309
+ size_t num_trigger_patterns,
1310
+ const llama_token * trigger_tokens,
1311
+ size_t num_trigger_tokens);
1312
+
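A hedged sketch of the lazy variant; the trigger pattern, the grammar, and the <tool_call> convention are made up for illustration and are not prescribed by the API:

```cpp
#include "llama.h"

// Sketch: only enforce a grammar once the model has emitted "<tool_call>".
static llama_sampler * make_lazy_tool_call_sampler(const llama_vocab * vocab) {
    const char * grammar = "root ::= \"<tool_call>\" [^<]* \"</tool_call>\"";

    // the grammar is fed everything from the first match group onwards
    const char * patterns[] = {
        "[\\s\\S]*?(<tool_call>[\\s\\S]*)",
    };

    return llama_sampler_init_grammar_lazy_patterns(
        vocab,
        grammar, "root",
        patterns, 1,
        /*trigger_tokens=*/ nullptr, /*num_trigger_tokens=*/ 0);
}
```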
1313
 
1314
  /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
1315
  LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
examples/talk-llama/unicode.cpp CHANGED
@@ -618,7 +618,14 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
     result.reserve(utf8.size());
     size_t offset = 0;
     while (offset < utf8.size()) {
-        result.push_back(unicode_cpt_from_utf8(utf8, offset));
+        try {
+            result.push_back(unicode_cpt_from_utf8(utf8, offset));
+        }
+        catch (const std::invalid_argument & /*ex*/) {
+            // Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize
+            ++offset;
+            result.emplace_back(0xFFFD); // replacement character
+        }
     }
     return result;
 }

@@ -701,7 +708,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
     const auto cpts = unicode_cpts_from_utf8(text);
 
     // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
-    // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
+    // ref: https://github.com/ggml-org/llama.cpp/pull/6920#issuecomment-2081479935
     std::string text_collapsed;
     if (need_collapse) {
         // collapse all unicode categories
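The practical effect of the new try/catch is that malformed byte sequences no longer throw out of tokenization but are replaced with U+FFFD. A hedged sketch of what a caller of this internal helper (declared in unicode.h, not in the public llama.h API) now observes:

```cpp
#include "unicode.h"

#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

int main() {
    // 0x80 is a lone continuation byte and not valid UTF-8; previously the
    // decoder threw std::invalid_argument here, now it emits U+FFFD and moves on.
    const std::string bad = "a\x80z";

    const std::vector<uint32_t> cpts = unicode_cpts_from_utf8(bad);

    assert(cpts.size() == 3);
    assert(cpts[1] == 0xFFFD);

    return 0;
}
```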