talk-llama : sync llama.cpp (#3084)
- examples/talk-llama/CMakeLists.txt +3 -0
- examples/talk-llama/llama-adapter.cpp +55 -20
- examples/talk-llama/llama-adapter.h +11 -9
- examples/talk-llama/llama-arch.cpp +264 -33
- examples/talk-llama/llama-arch.h +33 -0
- examples/talk-llama/llama-batch.h +2 -2
- examples/talk-llama/llama-chat.cpp +76 -2
- examples/talk-llama/llama-chat.h +4 -0
- examples/talk-llama/llama-context.cpp +0 -0
- examples/talk-llama/llama-context.h +214 -77
- examples/talk-llama/llama-cparams.h +1 -0
- examples/talk-llama/llama-grammar.cpp +183 -183
- examples/talk-llama/llama-grammar.h +13 -4
- examples/talk-llama/llama-graph.cpp +1706 -0
- examples/talk-llama/llama-graph.h +596 -0
- examples/talk-llama/llama-hparams.cpp +8 -0
- examples/talk-llama/llama-hparams.h +21 -0
- examples/talk-llama/llama-impl.h +6 -6
- examples/talk-llama/llama-io.cpp +15 -0
- examples/talk-llama/llama-io.h +35 -0
- examples/talk-llama/llama-kv-cache.cpp +965 -303
- examples/talk-llama/llama-kv-cache.h +145 -150
- examples/talk-llama/llama-memory.cpp +1 -0
- examples/talk-llama/llama-memory.h +21 -0
- examples/talk-llama/llama-mmap.cpp +11 -1
- examples/talk-llama/llama-mmap.h +1 -0
- examples/talk-llama/llama-model-loader.cpp +10 -5
- examples/talk-llama/llama-model-loader.h +5 -3
- examples/talk-llama/llama-model.cpp +0 -0
- examples/talk-llama/llama-model.h +42 -1
- examples/talk-llama/llama-quant.cpp +39 -9
- examples/talk-llama/llama-sampling.cpp +179 -67
- examples/talk-llama/llama-vocab.cpp +55 -5
- examples/talk-llama/llama.cpp +0 -0
- examples/talk-llama/llama.h +147 -47
- examples/talk-llama/unicode.cpp +9 -2
examples/talk-llama/CMakeLists.txt
CHANGED
@@ -12,9 +12,12 @@ if (WHISPER_SDL2)
         llama-context.cpp
         llama-cparams.cpp
         llama-grammar.cpp
+        llama-graph.cpp
         llama-hparams.cpp
         llama-impl.cpp
+        llama-io.cpp
         llama-kv-cache.cpp
+        llama-memory.cpp
         llama-mmap.cpp
         llama-model-loader.cpp
         llama-model.cpp
examples/talk-llama/llama-adapter.cpp
CHANGED
@@ -4,14 +4,13 @@
 #include "llama-mmap.h"
 #include "llama-model.h"
 
-#include <algorithm>
 #include <map>
 #include <cassert>
 #include <stdexcept>
 
 // vec
 
-struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
+ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
     if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
         return nullptr;
     }
@@ -19,7 +18,7 @@ struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
     return tensors[il];
 }
 
-
+ggml_tensor * llama_adapter_cvec::apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const {
     ggml_tensor * layer_dir = tensor_for(il);
     if (layer_dir != nullptr) {
         cur = ggml_add(ctx, cur, layer_dir);
@@ -40,7 +39,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
-
+            ggml_init_params params = {
                 /*.mem_size   =*/ hparams.n_layer*ggml_tensor_overhead(),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
@@ -91,7 +90,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
     return true;
 }
 
-int32_t llama_adapter_cvec::apply(
+bool llama_adapter_cvec::apply(
         const llama_model & model,
         const float * data,
         size_t len,
@@ -104,17 +103,17 @@ int32_t llama_adapter_cvec::apply(
         // disable the current control vector (but leave allocated for later)
         layer_start = -1;
         layer_end   = -1;
-        return
+        return true;
     }
 
     if (n_embd != (int) hparams.n_embd) {
         LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
-        return
+        return false;
     }
 
     if (tensors.empty()) {
         if (!init(model)) {
-            return
+            return false;
         }
     }
 
@@ -130,12 +129,12 @@ int32_t llama_adapter_cvec::apply(
         }
     }
 
-    return
+    return true;
 }
 
 // lora
 
-llama_adapter_lora_weight * llama_adapter_lora::get_weight(
+llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) {
     const std::string name(w->name);
 
     const auto pos = ab_map.find(name);
@@ -146,11 +145,11 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor *
     return nullptr;
 }
 
-static void llama_adapter_lora_init_impl(
+static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) {
     LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
 
     ggml_context * ctx_init;
-
+    gguf_init_params meta_gguf_params = {
         /* .no_alloc = */ true,
         /* .ctx      = */ &ctx_init,
     };
@@ -201,7 +200,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             // add a new context
-
+            ggml_init_params params = {
                 /*.mem_size   =*/ n_tensors*ggml_tensor_overhead(),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
@@ -248,6 +247,26 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
         }
     }
 
+    // get extra buffer types of the CPU
+    // TODO: a more general solution for non-CPU extra buft should be imlpemented in the future
+    // ref: https://github.com/ggml-org/llama.cpp/pull/12593#pullrequestreview-2718659948
+    std::vector<ggml_backend_buffer_type_t> buft_extra;
+    {
+        auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+        auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
+
+        auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
+            ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
+
+        if (ggml_backend_dev_get_extra_bufts_fn) {
+            ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
+            while (extra_bufts && *extra_bufts) {
+                buft_extra.emplace_back(*extra_bufts);
+                ++extra_bufts;
+            }
+        }
+    }
+
     // add tensors
     for (auto & it : ab_map) {
         const std::string & name = it.first;
@@ -264,7 +283,23 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
             throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)");
         }
 
-
+        auto * buft = ggml_backend_buffer_get_type(model_tensor->buffer);
+
+        // do not load loras to extra buffer types (i.e. bufts for repacking) -> use the CPU in that case
+        for (auto & ex : buft_extra) {
+            if (ex == buft) {
+                LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
+
+                auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+                buft = ggml_backend_dev_buffer_type(cpu_dev);
+
+                break;
+            }
+        }
+
+        LLAMA_LOG_DEBUG("%s: lora for '%s' -> '%s'\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
+
+        ggml_context * dev_ctx = ctx_for_buft(buft);
         // validate tensor shape
         if (is_token_embd) {
             // expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd()
@@ -281,8 +316,8 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
         }
 
         // save tensor to adapter
-
-
+        ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
+        ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
         ggml_set_name(tensor_a, w.a->name);
         ggml_set_name(tensor_b, w.b->name);
         adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b);
@@ -308,7 +343,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
     {
         llama_file gguf_file(path_lora, "rb");
         std::vector<uint8_t> read_buf;
-        auto set_tensor = [&](
+        auto set_tensor = [&](ggml_tensor * orig, ggml_tensor * dev) {
             size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name));
             size_t size = ggml_nbytes(orig);
             read_buf.resize(size);
@@ -327,8 +362,8 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
     LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
 }
 
-
-
+llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) {
+    llama_adapter_lora * adapter = new llama_adapter_lora();
 
     try {
         llama_adapter_lora_init_impl(*model, path_lora, *adapter);
@@ -342,6 +377,6 @@ struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model,
         return nullptr;
     }
 
-void llama_adapter_lora_free(
+void llama_adapter_lora_free(llama_adapter_lora * adapter) {
     delete adapter;
 }
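The new block above discovers the CPU backend's "extra" (repacking) buffer types through a proc-address lookup, and falls back to the plain CPU buffer type when a LoRA tensor would otherwise land in one of them. A minimal stand-alone sketch of that lookup, using only the ggml-backend calls that appear in the diff; the helper name, the main() wrapper, and the ggml-backend.h include path are assumptions for illustration, not part of the change:

// Sketch only: enumerate the CPU backend's extra buffer types.
#include <cstdio>
#include <vector>

#include "ggml-backend.h"

static std::vector<ggml_backend_buffer_type_t> cpu_extra_bufts() {
    std::vector<ggml_backend_buffer_type_t> out;

    auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
    auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);

    // same proc-address lookup the adapter code uses to find repacking bufts
    auto fn = (ggml_backend_dev_get_extra_bufts_t)
        ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");

    if (fn) {
        for (ggml_backend_buffer_type_t * it = fn(cpu_dev); it && *it; ++it) {
            out.push_back(*it);
        }
    }
    return out;
}

int main() {
    for (auto * buft : cpu_extra_bufts()) {
        std::printf("extra buft: %s\n", ggml_backend_buft_name(buft));
    }
    return 0;
}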
examples/talk-llama/llama-adapter.h
CHANGED
@@ -15,11 +15,11 @@
 //
 
 struct llama_adapter_cvec {
-
+    ggml_tensor * tensor_for(int il) const;
 
-
+    ggml_tensor * apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const;
 
-
+    bool apply(
             const llama_model & model,
             const float * data,
             size_t len,
@@ -36,7 +36,7 @@ private:
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
-    std::vector<
+    std::vector<ggml_tensor *> tensors; // per layer
 };
 
 //
@@ -44,8 +44,8 @@ private:
 //
 
 struct llama_adapter_lora_weight {
-
-
+    ggml_tensor * a = nullptr;
+    ggml_tensor * b = nullptr;
 
     // get actual scale based on rank and alpha
     float get_scale(float alpha, float adapter_scale) const {
@@ -55,12 +55,12 @@ struct llama_adapter_lora_weight {
     }
 
     llama_adapter_lora_weight() = default;
-    llama_adapter_lora_weight(
+    llama_adapter_lora_weight(ggml_tensor * a, ggml_tensor * b) : a(a), b(b) {}
 };
 
 struct llama_adapter_lora {
     // map tensor name to lora_a_b
-    std::unordered_map<std::string,
+    std::unordered_map<std::string, llama_adapter_lora_weight> ab_map;
 
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
@@ -70,5 +70,7 @@ struct llama_adapter_lora {
     llama_adapter_lora() = default;
     ~llama_adapter_lora() = default;
 
-    llama_adapter_lora_weight * get_weight(
+    llama_adapter_lora_weight * get_weight(ggml_tensor * w);
 };
+
+using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>;
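The header keeps the per-pair A/B tensors and a get_scale(alpha, adapter_scale) accessor whose body is not shown in this view. A hedged sketch of the conventional LoRA scaling it is commented to compute ("actual scale based on rank and alpha"): effective scale = adapter_scale * alpha / rank, with the rank read from the B tensor. Both the formula and the choice of b->ne[0] as the rank are assumptions for illustration:

// Sketch, not the header's actual body: conventional LoRA scale.
static float lora_scale_sketch(const ggml_tensor * b, float alpha, float adapter_scale) {
    const float rank = (float) b->ne[0];           // assumed rank source
    return alpha ? adapter_scale * alpha / rank    // alpha/rank scaling when alpha is set
                 : adapter_scale;                  // otherwise just the user scale
}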
examples/talk-llama/llama-arch.cpp
CHANGED
@@ -6,6 +6,7 @@
 
 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LLAMA, "llama" },
+    { LLM_ARCH_LLAMA4, "llama4" },
     { LLM_ARCH_DECI, "deci" },
     { LLM_ARCH_FALCON, "falcon" },
     { LLM_ARCH_GROK, "grok" },
@@ -25,6 +26,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_QWEN2, "qwen2" },
     { LLM_ARCH_QWEN2MOE, "qwen2moe" },
     { LLM_ARCH_QWEN2VL, "qwen2vl" },
+    { LLM_ARCH_QWEN3, "qwen3" },
+    { LLM_ARCH_QWEN3MOE, "qwen3moe" },
     { LLM_ARCH_PHI2, "phi2" },
     { LLM_ARCH_PHI3, "phi3" },
     { LLM_ARCH_PHIMOE, "phimoe" },
@@ -36,6 +39,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_MINICPM3, "minicpm3" },
     { LLM_ARCH_GEMMA, "gemma" },
     { LLM_ARCH_GEMMA2, "gemma2" },
+    { LLM_ARCH_GEMMA3, "gemma3" },
     { LLM_ARCH_STARCODER2, "starcoder2" },
     { LLM_ARCH_MAMBA, "mamba" },
     { LLM_ARCH_XVERSE, "xverse" },
@@ -50,6 +54,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DEEPSEEK, "deepseek" },
     { LLM_ARCH_DEEPSEEK2, "deepseek2" },
     { LLM_ARCH_CHATGLM, "chatglm" },
+    { LLM_ARCH_GLM4, "glm4" },
     { LLM_ARCH_BITNET, "bitnet" },
     { LLM_ARCH_T5, "t5" },
     { LLM_ARCH_T5ENCODER, "t5encoder" },
@@ -58,10 +63,14 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_EXAONE, "exaone" },
     { LLM_ARCH_RWKV6, "rwkv6" },
     { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
+    { LLM_ARCH_RWKV7, "rwkv7" },
+    { LLM_ARCH_ARWKV7, "arwkv7" },
     { LLM_ARCH_GRANITE, "granite" },
     { LLM_ARCH_GRANITE_MOE, "granitemoe" },
     { LLM_ARCH_CHAMELEON, "chameleon" },
     { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
+    { LLM_ARCH_PLM, "plm" },
+    { LLM_ARCH_BAILINGMOE, "bailingmoe" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
 
@@ -70,6 +79,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },
     { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" },
     { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" },
+    { LLM_KV_GENERAL_FILE_TYPE, "general.file_type" },
     { LLM_KV_GENERAL_NAME, "general.name" },
     { LLM_KV_GENERAL_AUTHOR, "general.author" },
     { LLM_KV_GENERAL_VERSION, "general.version" },
@@ -108,23 +118,30 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
     { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
     { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
+    { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },
 
-    { LLM_KV_ATTENTION_HEAD_COUNT,
-    { LLM_KV_ATTENTION_HEAD_COUNT_KV,
-    { LLM_KV_ATTENTION_MAX_ALIBI_BIAS,
-    { LLM_KV_ATTENTION_CLAMP_KQV,
-    { LLM_KV_ATTENTION_KEY_LENGTH,
-    { LLM_KV_ATTENTION_VALUE_LENGTH,
-    { LLM_KV_ATTENTION_LAYERNORM_EPS,
-    { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
-    { LLM_KV_ATTENTION_GROUPNORM_EPS,
-    { LLM_KV_ATTENTION_GROUPNORM_GROUPS,
-    { LLM_KV_ATTENTION_CAUSAL,
-    { LLM_KV_ATTENTION_Q_LORA_RANK,
-    { LLM_KV_ATTENTION_KV_LORA_RANK,
-    {
-    {
-    {
+    { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
+    { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
+    { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
+    { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" },
+    { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" },
+    { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
+    { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
+    { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
+    { LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
+    { LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
+    { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
+    { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
+    { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
+    { LLM_KV_ATTENTION_DECAY_LORA_RANK, "%s.attention.decay_lora_rank" },
+    { LLM_KV_ATTENTION_ICLR_LORA_RANK, "%s.attention.iclr_lora_rank" },
+    { LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, "%s.attention.value_residual_mix_lora_rank" },
+    { LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" },
+    { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
+    { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
+    { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
+    { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
+    { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
 
     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@@ -223,6 +240,35 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_LLAMA4,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+        },
+    },
     {
         LLM_ARCH_DECI,
         {
@@ -554,6 +600,45 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
         },
     },
+    {
+        LLM_ARCH_QWEN3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
+    {
+        LLM_ARCH_QWEN3MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_PHI2,
         {
@@ -766,6 +851,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
         },
     },
+    {
+        LLM_ARCH_GEMMA3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+        },
+    },
     {
         LLM_ARCH_STARCODER2,
         {
@@ -999,6 +1105,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
             { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
             { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
+            { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
+            { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
             { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
             { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
             { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
@@ -1015,6 +1123,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
         },
     },
+    {
+        LLM_ARCH_PLM,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
+            { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
+            { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_CHATGLM,
         {
@@ -1033,6 +1157,25 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
         },
     },
+    {
+        LLM_ARCH_GLM4,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+        },
+    },
     {
         LLM_ARCH_BITNET,
         {
@@ -1217,6 +1360,74 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_RWKV7,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
+            { LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
+            { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
+            { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
+            { LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
+            { LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
+            { LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
+            { LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
+            { LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
+            { LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
+            { LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
+            { LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
+            { LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
+            { LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
+            { LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
+            { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
+            { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
+            { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
+            { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
+            { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
+            { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
+            { LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" },
+            { LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" },
+            { LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" },
+        },
+    },
+    {
+        LLM_ARCH_ARWKV7,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
+            { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
+            { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
+            { LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
+            { LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
+            { LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
+            { LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
+            { LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
+            { LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
+            { LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
+            { LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
+            { LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
+            { LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
+            { LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
+            { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
+            { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
+            { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
+            { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
+            { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
+            { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_GRANITE,
         {
@@ -1296,6 +1507,29 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
         },
     },
+    {
+        LLM_ARCH_BAILINGMOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -1333,23 +1567,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {
-    {
-    {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@@ -1376,6 +1595,12 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_A2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_V1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_V2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_G1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_G2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@@ -1394,6 +1619,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_TIME_MIX_K_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_TIME_MIX_K_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_TIME_MIX_R_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
@@ -1401,6 +1629,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_W0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_A0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_V0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
     {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
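The per-layer entries in these maps are printf-style patterns ("blk.%d.attn_q", "blk.%d.ffn_gate.%d"), so concrete GGUF tensor names are produced by substituting the layer (and, for per-expert tensors, the expert) index. A small stand-alone sketch of that substitution; the helper name is hypothetical and only illustrates the pattern, it is not a function from llama-arch.cpp:

// Sketch only: format a per-layer tensor name from one of the patterns above.
#include <cstdio>
#include <string>

static std::string llm_tensor_name_sketch(const char * pattern, int il) {
    char buf[128];
    std::snprintf(buf, sizeof(buf), pattern, il); // "blk.%d.attn_q_norm" + 3 -> "blk.3.attn_q_norm"
    return buf;
}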
examples/talk-llama/llama-arch.h
CHANGED
@@ -10,6 +10,7 @@
 
 enum llm_arch {
     LLM_ARCH_LLAMA,
+    LLM_ARCH_LLAMA4,
     LLM_ARCH_DECI,
     LLM_ARCH_FALCON,
     LLM_ARCH_BAICHUAN,
@@ -29,6 +30,8 @@ enum llm_arch {
     LLM_ARCH_QWEN2,
     LLM_ARCH_QWEN2MOE,
     LLM_ARCH_QWEN2VL,
+    LLM_ARCH_QWEN3,
+    LLM_ARCH_QWEN3MOE,
     LLM_ARCH_PHI2,
     LLM_ARCH_PHI3,
     LLM_ARCH_PHIMOE,
@@ -40,6 +43,7 @@ enum llm_arch {
     LLM_ARCH_MINICPM3,
     LLM_ARCH_GEMMA,
     LLM_ARCH_GEMMA2,
+    LLM_ARCH_GEMMA3,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_XVERSE,
@@ -54,6 +58,7 @@ enum llm_arch {
     LLM_ARCH_DEEPSEEK,
     LLM_ARCH_DEEPSEEK2,
     LLM_ARCH_CHATGLM,
+    LLM_ARCH_GLM4,
     LLM_ARCH_BITNET,
     LLM_ARCH_T5,
     LLM_ARCH_T5ENCODER,
@@ -62,10 +67,14 @@ enum llm_arch {
     LLM_ARCH_EXAONE,
     LLM_ARCH_RWKV6,
     LLM_ARCH_RWKV6QWEN2,
+    LLM_ARCH_RWKV7,
+    LLM_ARCH_ARWKV7,
     LLM_ARCH_GRANITE,
     LLM_ARCH_GRANITE_MOE,
     LLM_ARCH_CHAMELEON,
     LLM_ARCH_WAVTOKENIZER_DEC,
+    LLM_ARCH_PLM,
+    LLM_ARCH_BAILINGMOE,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -74,6 +83,7 @@ enum llm_kv {
     LLM_KV_GENERAL_ARCHITECTURE,
     LLM_KV_GENERAL_QUANTIZATION_VERSION,
     LLM_KV_GENERAL_ALIGNMENT,
+    LLM_KV_GENERAL_FILE_TYPE,
     LLM_KV_GENERAL_NAME,
     LLM_KV_GENERAL_AUTHOR,
     LLM_KV_GENERAL_VERSION,
@@ -112,6 +122,7 @@ enum llm_kv {
     LLM_KV_RESIDUAL_SCALE,
     LLM_KV_EMBEDDING_SCALE,
     LLM_KV_TOKEN_SHIFT_COUNT,
+    LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -126,9 +137,15 @@ enum llm_kv {
     LLM_KV_ATTENTION_CAUSAL,
     LLM_KV_ATTENTION_Q_LORA_RANK,
     LLM_KV_ATTENTION_KV_LORA_RANK,
+    LLM_KV_ATTENTION_DECAY_LORA_RANK,
+    LLM_KV_ATTENTION_ICLR_LORA_RANK,
+    LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK,
+    LLM_KV_ATTENTION_GATE_LORA_RANK,
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
+    LLM_KV_ATTENTION_KEY_LENGTH_MLA,
+    LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -242,6 +259,8 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_NORM,
     LLM_TENSOR_ATTN_K_NORM,
     LLM_TENSOR_LAYER_OUT_NORM,
+    LLM_TENSOR_POST_ATTN_NORM,
+    LLM_TENSOR_POST_MLP_NORM,
     LLM_TENSOR_SSM_IN,
     LLM_TENSOR_SSM_CONV1D,
     LLM_TENSOR_SSM_X,
@@ -249,8 +268,20 @@ enum llm_tensor {
     LLM_TENSOR_SSM_A,
     LLM_TENSOR_SSM_D,
     LLM_TENSOR_SSM_OUT,
+    LLM_TENSOR_TIME_MIX_W0,
     LLM_TENSOR_TIME_MIX_W1,
     LLM_TENSOR_TIME_MIX_W2,
+    LLM_TENSOR_TIME_MIX_A0,
+    LLM_TENSOR_TIME_MIX_A1,
+    LLM_TENSOR_TIME_MIX_A2,
+    LLM_TENSOR_TIME_MIX_V0,
+    LLM_TENSOR_TIME_MIX_V1,
+    LLM_TENSOR_TIME_MIX_V2,
+    LLM_TENSOR_TIME_MIX_G1,
+    LLM_TENSOR_TIME_MIX_G2,
+    LLM_TENSOR_TIME_MIX_K_K,
+    LLM_TENSOR_TIME_MIX_K_A,
+    LLM_TENSOR_TIME_MIX_R_K,
     LLM_TENSOR_TIME_MIX_LERP_X,
     LLM_TENSOR_TIME_MIX_LERP_W,
     LLM_TENSOR_TIME_MIX_LERP_K,
@@ -277,6 +308,8 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_B,
     LLM_TENSOR_ATTN_KV_A_MQA,
     LLM_TENSOR_ATTN_KV_B,
+    LLM_TENSOR_ATTN_K_B,
+    LLM_TENSOR_ATTN_V_B,
     LLM_TENSOR_ATTN_Q_A_NORM,
     LLM_TENSOR_ATTN_KV_A_NORM,
     LLM_TENSOR_ATTN_SUB_NORM,
examples/talk-llama/llama-batch.h
CHANGED
@@ -42,9 +42,9 @@ struct llama_sbatch {
     bool logits_all; // TODO: remove once lctx.logits_all is removed too
 
     // sorted indices into the batch
-    std::vector<
+    std::vector<int64_t> ids;
     // batch indices of the output
-    std::vector<
+    std::vector<int64_t> out_ids;
     std::vector<llama_sbatch_seq> seq;
 
     const llama_batch * batch = nullptr;
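llama_sbatch keeps the original batch untouched and works through index vectors: ids is a sorted permutation of batch positions and out_ids records which batch positions produce output. A minimal, self-contained illustration of that index-vector idea; the toy struct, the helper name, and the choice of sequence id as the sort key are assumptions for the example, not the actual llama-batch.h logic:

// Sketch only: build a sorted index vector instead of reordering the batch.
#include <algorithm>
#include <cstdint>
#include <vector>

struct toy_batch {
    std::vector<int32_t> seq_id; // one sequence id per token
};

static std::vector<int64_t> sorted_batch_indices(const toy_batch & batch) {
    std::vector<int64_t> ids(batch.seq_id.size());
    for (int64_t i = 0; i < (int64_t) ids.size(); ++i) {
        ids[i] = i;
    }
    // stable sort keeps token order within each sequence
    std::stable_sort(ids.begin(), ids.end(), [&](int64_t a, int64_t b) {
        return batch.seq_id[a] < batch.seq_id[b];
    });
    return ids;
}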
examples/talk-llama/llama-chat.cpp
CHANGED
|
@@ -4,6 +4,7 @@
|
|
| 4 |
|
| 5 |
#include <map>
|
| 6 |
#include <sstream>
|
|
|
|
| 7 |
|
| 8 |
#if __cplusplus >= 202000L
|
| 9 |
#define LU8(x) (const char*)(u8##x)
|
|
@@ -58,6 +59,10 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
|
| 58 |
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
|
| 59 |
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
|
| 60 |
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
};
|
| 62 |
|
| 63 |
llm_chat_template llm_chat_template_from_str(const std::string & name) {
|
|
@@ -77,7 +82,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
|
| 77 |
if (tmpl_contains("<|im_start|>")) {
|
| 78 |
return tmpl_contains("<|im_sep|>")
|
| 79 |
? LLM_CHAT_TEMPLATE_PHI_4
|
| 80 |
-
:
|
|
|
|
|
|
|
| 81 |
} else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
|
| 82 |
if (tmpl_contains("[SYSTEM_PROMPT]")) {
|
| 83 |
return LLM_CHAT_TEMPLATE_MISTRAL_V7;
|
|
@@ -117,6 +124,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
|
| 117 |
return LLM_CHAT_TEMPLATE_PHI_3;
|
| 118 |
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
|
| 119 |
return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
|
|
|
|
|
|
|
| 120 |
} else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
|
| 121 |
return LLM_CHAT_TEMPLATE_ZEPHYR;
|
| 122 |
} else if (tmpl_contains("bos_token + message['role']")) {
|
|
@@ -167,6 +176,12 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
|
| 167 |
return LLM_CHAT_TEMPLATE_GIGACHAT;
|
| 168 |
} else if (tmpl_contains("<|role_start|>")) {
|
| 169 |
return LLM_CHAT_TEMPLATE_MEGREZ;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
}
|
| 171 |
return LLM_CHAT_TEMPLATE_UNKNOWN;
|
| 172 |
}
|
|
@@ -566,6 +581,66 @@ int32_t llm_chat_apply_template(
|
|
| 566 |
if (add_ass) {
|
| 567 |
ss << "<|role_start|>assistant<|role_end|>";
|
| 568 |
}
|
|
| 569 |
} else {
|
| 570 |
// template not supported
|
| 571 |
return -1;
|
|
@@ -584,4 +659,3 @@ int32_t llama_chat_builtin_templates(const char ** output, size_t len) {
|
|
| 584 |
}
|
| 585 |
return (int32_t) LLM_CHAT_TEMPLATES.size();
|
| 586 |
}
|
| 587 |
-
|
|
|
|
| 4 |
|
| 5 |
#include <map>
|
| 6 |
#include <sstream>
|
| 7 |
+
#include <algorithm>
|
| 8 |
|
| 9 |
#if __cplusplus >= 202000L
|
| 10 |
#define LU8(x) (const char*)(u8##x)
|
|
|
|
| 59 |
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
|
| 60 |
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
|
| 61 |
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
|
| 62 |
+
{ "yandex", LLM_CHAT_TEMPLATE_YANDEX },
|
| 63 |
+
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
|
| 64 |
+
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
|
| 65 |
+
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
|
| 66 |
};
|
| 67 |
|
| 68 |
llm_chat_template llm_chat_template_from_str(const std::string & name) {
|
|
|
|
| 82 |
if (tmpl_contains("<|im_start|>")) {
|
| 83 |
return tmpl_contains("<|im_sep|>")
|
| 84 |
? LLM_CHAT_TEMPLATE_PHI_4
|
| 85 |
+
: tmpl_contains("<end_of_utterance>")
|
| 86 |
+
? LLM_CHAT_TEMPLATE_SMOLVLM // SmolVLM uses <|im_start|> as BOS, but it is NOT chatml
|
| 87 |
+
: LLM_CHAT_TEMPLATE_CHATML;
|
| 88 |
} else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
|
| 89 |
if (tmpl_contains("[SYSTEM_PROMPT]")) {
|
| 90 |
return LLM_CHAT_TEMPLATE_MISTRAL_V7;
|
|
|
|
| 124 |
return LLM_CHAT_TEMPLATE_PHI_3;
|
| 125 |
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
|
| 126 |
return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
|
| 127 |
+
} else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) {
|
| 128 |
+
return LLM_CHAT_TEMPLATE_GLMEDGE;
|
| 129 |
} else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
|
| 130 |
return LLM_CHAT_TEMPLATE_ZEPHYR;
|
| 131 |
} else if (tmpl_contains("bos_token + message['role']")) {
|
|
|
|
| 176 |
return LLM_CHAT_TEMPLATE_GIGACHAT;
|
| 177 |
} else if (tmpl_contains("<|role_start|>")) {
|
| 178 |
return LLM_CHAT_TEMPLATE_MEGREZ;
|
| 179 |
+
} else if (tmpl_contains(" Ассистент:")) {
|
| 180 |
+
return LLM_CHAT_TEMPLATE_YANDEX;
|
| 181 |
+
} else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) {
|
| 182 |
+
return LLM_CHAT_TEMPLATE_BAILING;
|
| 183 |
+
} else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
|
| 184 |
+
return LLM_CHAT_TEMPLATE_LLAMA4;
|
| 185 |
}
|
| 186 |
return LLM_CHAT_TEMPLATE_UNKNOWN;
|
| 187 |
}
|
|
|
|
| 581 |
if (add_ass) {
|
| 582 |
ss << "<|role_start|>assistant<|role_end|>";
|
| 583 |
}
|
| 584 |
+
} else if (tmpl == LLM_CHAT_TEMPLATE_YANDEX) {
|
| 585 |
+
// Yandex template ("\n\n" is defined as EOT token)
|
| 586 |
+
|
| 587 |
+
ss << "<s>";
|
| 588 |
+
|
| 589 |
+
for (size_t i = 0; i < chat.size(); i++) {
|
| 590 |
+
std::string role(chat[i]->role);
|
| 591 |
+
if (role == "user") {
|
| 592 |
+
ss << " Пользователь: " << chat[i]->content << "\n\n";
|
| 593 |
+
} else if (role == "assistant") {
|
| 594 |
+
ss << " Ассистент: " << chat[i]->content << "\n\n";
|
| 595 |
+
}
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
// Add generation prompt if needed
|
| 599 |
+
if (add_ass) {
|
| 600 |
+
ss << " Ассистент:[SEP]";
|
| 601 |
+
}
|
| 602 |
+
} else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) {
|
| 603 |
+
// Bailing (Ling) template
|
| 604 |
+
for (auto message : chat) {
|
| 605 |
+
std::string role(message->role);
|
| 606 |
+
|
| 607 |
+
if (role == "user") {
|
| 608 |
+
role = "HUMAN";
|
| 609 |
+
} else {
|
| 610 |
+
std::transform(role.begin(), role.end(), role.begin(), ::toupper);
|
| 611 |
+
}
|
| 612 |
+
|
| 613 |
+
ss << "<role>" << role << "</role>" << message->content;
|
| 614 |
+
}
|
| 615 |
+
|
| 616 |
+
if (add_ass) {
|
| 617 |
+
ss << "<role>ASSISTANT</role>";
|
| 618 |
+
}
|
| 619 |
+
} else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA4) {
|
| 620 |
+
// Llama 4
|
| 621 |
+
for (auto message : chat) {
|
| 622 |
+
std::string role(message->role);
|
| 623 |
+
ss << "<|header_start|>" << role << "<|header_end|>\n\n" << trim(message->content) << "<|eot|>";
|
| 624 |
+
}
|
| 625 |
+
if (add_ass) {
|
| 626 |
+
ss << "<|header_start|>assistant<|header_end|>\n\n";
|
| 627 |
+
}
|
| 628 |
+
} else if (tmpl == LLM_CHAT_TEMPLATE_SMOLVLM) {
|
| 629 |
+
// SmolVLM
|
| 630 |
+
ss << "<|im_start|>"; // uses <|im_start|> as BOS, but the actual content is NOT chatml
|
| 631 |
+
for (auto message : chat) {
|
| 632 |
+
std::string role(message->role);
|
| 633 |
+
if (role == "system") {
|
| 634 |
+
ss << message->content << "\n\n";
|
| 635 |
+
} else if (role == "user") {
|
| 636 |
+
ss << "User: " << message->content << "<end_of_utterance>\n";
|
| 637 |
+
} else {
|
| 638 |
+
ss << "Assistant: " << message->content << "<end_of_utterance>\n";
|
| 639 |
+
}
|
| 640 |
+
}
|
| 641 |
+
if (add_ass) {
|
| 642 |
+
ss << "Assistant:";
|
| 643 |
+
}
|
| 644 |
} else {
|
| 645 |
// template not supported
|
| 646 |
return -1;
|
|
|
|
| 659 |
}
|
| 660 |
return (int32_t) LLM_CHAT_TEMPLATES.size();
|
| 661 |
}
|
|
|
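As a concrete illustration of the new LLAMA4 branch of llm_chat_apply_template added above, the following standalone sketch reproduces the <|header_start|>/<|header_end|>/<|eot|> layout for a made-up two-message chat (the roles and message text are hypothetical, and unlike the real code this sketch does not trim the content):

    #include <cstdio>
    #include <string>
    #include <utility>
    #include <vector>

    int main() {
        // hypothetical conversation; roles/content are illustration only
        std::vector<std::pair<std::string, std::string>> chat = {
            { "user",      "Hello"               },
            { "assistant", "Hi, how can I help?" },
        };

        std::string ss;
        for (const auto & m : chat) {
            // mirrors: "<|header_start|>" << role << "<|header_end|>\n\n" << trim(content) << "<|eot|>"
            ss += "<|header_start|>" + m.first + "<|header_end|>\n\n" + m.second + "<|eot|>";
        }

        // add_ass == true: open the assistant turn for generation
        ss += "<|header_start|>assistant<|header_end|>\n\n";

        printf("%s\n", ss.c_str());
        return 0;
    }
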
examples/talk-llama/llama-chat.h
CHANGED
|
@@ -38,6 +38,10 @@ enum llm_chat_template {
|
|
| 38 |
LLM_CHAT_TEMPLATE_GRANITE,
|
| 39 |
LLM_CHAT_TEMPLATE_GIGACHAT,
|
| 40 |
LLM_CHAT_TEMPLATE_MEGREZ,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
LLM_CHAT_TEMPLATE_UNKNOWN,
|
| 42 |
};
|
| 43 |
|
|
|
|
| 38 |
LLM_CHAT_TEMPLATE_GRANITE,
|
| 39 |
LLM_CHAT_TEMPLATE_GIGACHAT,
|
| 40 |
LLM_CHAT_TEMPLATE_MEGREZ,
|
| 41 |
+
LLM_CHAT_TEMPLATE_YANDEX,
|
| 42 |
+
LLM_CHAT_TEMPLATE_BAILING,
|
| 43 |
+
LLM_CHAT_TEMPLATE_LLAMA4,
|
| 44 |
+
LLM_CHAT_TEMPLATE_SMOLVLM,
|
| 45 |
LLM_CHAT_TEMPLATE_UNKNOWN,
|
| 46 |
};
|
| 47 |
|
examples/talk-llama/llama-context.cpp
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/talk-llama/llama-context.h
CHANGED
|
@@ -3,66 +3,213 @@
|
|
| 3 |
#include "llama.h"
|
| 4 |
#include "llama-batch.h"
|
| 5 |
#include "llama-cparams.h"
|
| 6 |
-
#include "llama-
|
| 7 |
-
#include "llama-kv-cache.h"
|
| 8 |
#include "llama-adapter.h"
|
| 9 |
|
| 10 |
#include "ggml-cpp.h"
|
| 11 |
|
| 12 |
#include <map>
|
| 13 |
-
#include <unordered_map>
|
| 14 |
#include <vector>
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
struct llama_context {
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
|
| 23 |
-
|
| 24 |
|
| 25 |
-
|
| 26 |
-
struct llama_sbatch sbatch; // TODO: revisit if needed
|
| 27 |
-
struct llama_kv_cache kv_self;
|
| 28 |
-
struct llama_adapter_cvec cvec;
|
| 29 |
|
| 30 |
-
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
-
|
|
|
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
|
| 40 |
-
|
| 41 |
|
| 42 |
-
|
| 43 |
-
mutable int64_t t_load_us;
|
| 44 |
-
mutable int64_t t_p_eval_us = 0;
|
| 45 |
-
mutable int64_t t_eval_us = 0;
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
|
|
|
| 52 |
|
| 53 |
-
|
| 54 |
-
|
|
|
|
| 55 |
|
| 56 |
-
|
| 57 |
-
size_t logits_size = 0; // capacity (of floats) for logits
|
| 58 |
-
float * logits = nullptr;
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
|
bool logits_all = false;
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
| 67 |
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
| 68 |
size_t embd_size = 0; // capacity (of floats) for embeddings
|
|
@@ -72,57 +219,47 @@ struct llama_context {
|
|
| 72 |
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
|
| 73 |
std::map<llama_seq_id, std::vector<float>> embd_seq;
|
| 74 |
|
| 75 |
-
//
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
// TODO: find a better way to accommodate mutli-dimension position encoding methods
|
| 79 |
-
// number of position id each token get, 1 for each token in most cases.
|
| 80 |
-
// when using m-rope, it will be 3 position ids per token to representing 3 dimension coordinate.
|
| 81 |
-
int n_pos_per_token = 1;
|
| 82 |
|
| 83 |
-
//
|
| 84 |
-
std::vector<float> embd_enc;
|
| 85 |
-
std::vector<std::set<llama_seq_id>> seq_ids_enc;
|
| 86 |
|
| 87 |
-
// memory buffers used to evaluate the model
|
| 88 |
-
std::vector<uint8_t> buf_compute_meta;
|
| 89 |
ggml_backend_sched_ptr sched;
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
ggml_abort_callback abort_callback = nullptr;
|
| 92 |
void * abort_callback_data = nullptr;
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
|
| 100 |
-
struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
|
| 101 |
-
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
|
| 102 |
-
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
|
| 103 |
-
struct ggml_tensor * inp_cls; // I32 [n_batch]
|
| 104 |
-
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
|
| 105 |
-
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
|
| 106 |
-
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
|
| 107 |
-
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
|
| 108 |
-
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
|
| 109 |
-
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
|
| 110 |
-
};
|
| 111 |
|
| 112 |
-
//
|
| 113 |
-
|
| 114 |
|
| 115 |
-
|
|
|
|
| 116 |
|
| 117 |
-
|
| 118 |
|
| 119 |
-
//
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
| 122 |
|
| 123 |
-
|
| 124 |
-
|
| 125 |
|
| 126 |
-
//
|
| 127 |
-
//
|
| 128 |
-
|
|
|
|
| 3 |
#include "llama.h"
|
| 4 |
#include "llama-batch.h"
|
| 5 |
#include "llama-cparams.h"
|
| 6 |
+
#include "llama-graph.h"
|
|
|
|
| 7 |
#include "llama-adapter.h"
|
| 8 |
|
| 9 |
#include "ggml-cpp.h"
|
| 10 |
|
| 11 |
#include <map>
|
|
|
|
| 12 |
#include <vector>
|
| 13 |
+
|
| 14 |
+
struct llama_model;
|
| 15 |
+
struct llama_kv_cache;
|
| 16 |
+
|
| 17 |
+
class llama_io_read_i;
|
| 18 |
+
class llama_io_write_i;
|
| 19 |
|
| 20 |
struct llama_context {
|
| 21 |
+
// init scheduler and compute buffers, reserve worst-case graphs
|
| 22 |
+
llama_context(
|
| 23 |
+
const llama_model & model,
|
| 24 |
+
llama_context_params params);
|
| 25 |
|
| 26 |
+
~llama_context();
|
| 27 |
|
| 28 |
+
void synchronize();
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
+
const llama_model & get_model() const;
|
| 31 |
|
| 32 |
+
uint32_t n_ctx() const;
|
| 33 |
+
uint32_t n_ctx_per_seq() const;
|
| 34 |
+
uint32_t n_batch() const;
|
| 35 |
+
uint32_t n_ubatch() const;
|
| 36 |
+
uint32_t n_seq_max() const;
|
| 37 |
|
| 38 |
+
uint32_t n_threads() const;
|
| 39 |
+
uint32_t n_threads_batch() const;
|
| 40 |
|
| 41 |
+
llama_kv_cache * get_kv_self();
|
| 42 |
+
const llama_kv_cache * get_kv_self() const;
|
| 43 |
|
| 44 |
+
void kv_self_update();
|
| 45 |
|
| 46 |
+
enum llama_pooling_type pooling_type() const;
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
+
float * get_logits();
|
| 49 |
+
float * get_logits_ith(int32_t i);
|
| 50 |
|
| 51 |
+
float * get_embeddings();
|
| 52 |
+
float * get_embeddings_ith(int32_t i);
|
| 53 |
+
float * get_embeddings_seq(llama_seq_id seq_id);
|
| 54 |
|
| 55 |
+
void attach_threadpool(
|
| 56 |
+
ggml_threadpool_t threadpool,
|
| 57 |
+
ggml_threadpool_t threadpool_batch);
|
| 58 |
|
| 59 |
+
void detach_threadpool();
|
|
|
|
|
|
|
| 60 |
|
| 61 |
+
void set_n_threads(int32_t n_threads, int32_t n_threads_batch);
|
| 62 |
+
|
| 63 |
+
void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
|
| 64 |
+
|
| 65 |
+
void set_embeddings (bool value);
|
| 66 |
+
void set_causal_attn(bool value);
|
| 67 |
+
void set_warmup(bool value);
|
| 68 |
+
|
| 69 |
+
void set_adapter_lora(
|
| 70 |
+
llama_adapter_lora * adapter,
|
| 71 |
+
float scale);
|
| 72 |
+
|
| 73 |
+
bool rm_adapter_lora(
|
| 74 |
+
llama_adapter_lora * adapter);
|
| 75 |
+
|
| 76 |
+
void clear_adapter_lora();
|
| 77 |
+
|
| 78 |
+
bool apply_adapter_cvec(
|
| 79 |
+
const float * data,
|
| 80 |
+
size_t len,
|
| 81 |
+
int32_t n_embd,
|
| 82 |
+
int32_t il_start,
|
| 83 |
+
int32_t il_end);
|
| 84 |
+
|
| 85 |
+
int encode(llama_batch & inp_batch);
|
| 86 |
+
int decode(llama_batch & inp_batch);
|
| 87 |
+
|
| 88 |
+
//
|
| 89 |
+
// state save/load
|
| 90 |
+
//
|
| 91 |
+
|
| 92 |
+
size_t state_get_size();
|
| 93 |
+
size_t state_get_data( uint8_t * dst, size_t size);
|
| 94 |
+
size_t state_set_data(const uint8_t * src, size_t size);
|
| 95 |
+
|
| 96 |
+
size_t state_seq_get_size(llama_seq_id seq_id);
|
| 97 |
+
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
|
| 98 |
+
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
|
| 99 |
+
|
| 100 |
+
bool state_load_file(
|
| 101 |
+
const char * filepath,
|
| 102 |
+
llama_token * tokens_out,
|
| 103 |
+
size_t n_token_capacity,
|
| 104 |
+
size_t * n_token_count_out);
|
| 105 |
+
|
| 106 |
+
bool state_save_file(
|
| 107 |
+
const char * filepath,
|
| 108 |
+
const llama_token * tokens,
|
| 109 |
+
size_t n_token_count);
|
| 110 |
+
|
| 111 |
+
size_t state_seq_load_file(
|
| 112 |
+
llama_seq_id seq_id,
|
| 113 |
+
const char * filepath,
|
| 114 |
+
llama_token * tokens_out,
|
| 115 |
+
size_t n_token_capacity,
|
| 116 |
+
size_t * n_token_count_out);
|
| 117 |
+
|
| 118 |
+
size_t state_seq_save_file(
|
| 119 |
+
llama_seq_id seq_id,
|
| 120 |
+
const char * filepath,
|
| 121 |
+
const llama_token * tokens,
|
| 122 |
+
size_t n_token_count);
|
| 123 |
+
|
| 124 |
+
//
|
| 125 |
+
// perf
|
| 126 |
+
//
|
| 127 |
+
|
| 128 |
+
llama_perf_context_data perf_get_data() const;
|
| 129 |
+
void perf_reset();
|
| 130 |
+
|
| 131 |
+
private:
|
| 132 |
+
//
|
| 133 |
+
// output
|
| 134 |
+
//
|
| 135 |
|
| 136 |
+
// Make sure enough space is available for outputs.
|
| 137 |
+
// Returns max number of outputs for which space was reserved.
|
| 138 |
+
int32_t output_reserve(int32_t n_outputs);
|
| 139 |
+
|
| 140 |
+
// make the outputs have the same order they had in the user-provided batch
|
| 141 |
+
// TODO: maybe remove this
|
| 142 |
+
void output_reorder();
|
| 143 |
+
|
| 144 |
+
//
|
| 145 |
+
// graph
|
| 146 |
+
//
|
| 147 |
+
|
| 148 |
+
int32_t graph_max_nodes() const;
|
| 149 |
+
|
| 150 |
+
// zero-out inputs and create the ctx_compute for the compute graph
|
| 151 |
+
ggml_cgraph * graph_init();
|
| 152 |
+
|
| 153 |
+
llm_graph_result_ptr graph_build(
|
| 154 |
+
ggml_context * ctx,
|
| 155 |
+
ggml_cgraph * gf,
|
| 156 |
+
const llama_ubatch & ubatch,
|
| 157 |
+
llm_graph_type gtype);
|
| 158 |
+
|
| 159 |
+
// returns the result of ggml_backend_sched_graph_compute_async execution
|
| 160 |
+
ggml_status graph_compute(
|
| 161 |
+
ggml_cgraph * gf,
|
| 162 |
+
bool batched);
|
| 163 |
+
|
| 164 |
+
llm_graph_cb graph_get_cb() const;
|
| 165 |
+
|
| 166 |
+
// used by kv_self_update()
|
| 167 |
+
ggml_tensor * build_rope_shift(
|
| 168 |
+
ggml_context * ctx0,
|
| 169 |
+
ggml_tensor * cur,
|
| 170 |
+
ggml_tensor * shift,
|
| 171 |
+
ggml_tensor * factors,
|
| 172 |
+
float freq_base,
|
| 173 |
+
float freq_scale,
|
| 174 |
+
ggml_backend_buffer * bbuf) const;
|
| 175 |
+
|
| 176 |
+
llm_graph_result_ptr build_kv_self_shift(
|
| 177 |
+
ggml_context * ctx0,
|
| 178 |
+
ggml_cgraph * gf) const;
|
| 179 |
+
|
| 180 |
+
llm_graph_result_ptr build_kv_self_defrag(
|
| 181 |
+
ggml_context * ctx0,
|
| 182 |
+
ggml_cgraph * gf) const;
|
| 183 |
+
|
| 184 |
+
// TODO: read/write lora adapters and cvec
|
| 185 |
+
size_t state_write_data(llama_io_write_i & io);
|
| 186 |
+
size_t state_read_data (llama_io_read_i & io);
|
| 187 |
+
|
| 188 |
+
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
|
| 189 |
+
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);
|
| 190 |
+
|
| 191 |
+
//
|
| 192 |
+
// members
|
| 193 |
+
//
|
| 194 |
+
|
| 195 |
+
const llama_model & model;
|
| 196 |
+
|
| 197 |
+
llama_cparams cparams;
|
| 198 |
+
llama_adapter_cvec cvec;
|
| 199 |
+
llama_adapter_loras loras;
|
| 200 |
+
llama_sbatch sbatch;
|
| 201 |
+
|
| 202 |
+
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
|
| 203 |
+
|
| 204 |
+
std::unique_ptr<llama_kv_cache_unified> kv_self;
|
| 205 |
+
|
| 206 |
+
// TODO: remove
|
| 207 |
bool logits_all = false;
|
| 208 |
|
| 209 |
+
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
| 210 |
+
size_t logits_size = 0; // capacity (of floats) for logits
|
| 211 |
+
float * logits = nullptr;
|
| 212 |
+
|
| 213 |
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
| 214 |
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
| 215 |
size_t embd_size = 0; // capacity (of floats) for embeddings
|
|
|
|
| 219 |
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
|
| 220 |
std::map<llama_seq_id, std::vector<float>> embd_seq;
|
| 221 |
|
| 222 |
+
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
|
| 223 |
+
int32_t n_outputs_max = 0; // capacity (of tokens positions) for the output buffers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
+
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
|
|
|
|
|
|
|
| 226 |
|
|
|
|
|
|
|
| 227 |
ggml_backend_sched_ptr sched;
|
| 228 |
|
| 229 |
+
ggml_backend_t backend_cpu = nullptr;
|
| 230 |
+
std::vector<ggml_backend_ptr> backends;
|
| 231 |
+
|
| 232 |
+
ggml_context_ptr ctx_compute;
|
| 233 |
+
|
| 234 |
+
ggml_threadpool_t threadpool = nullptr;
|
| 235 |
+
ggml_threadpool_t threadpool_batch = nullptr;
|
| 236 |
+
|
| 237 |
ggml_abort_callback abort_callback = nullptr;
|
| 238 |
void * abort_callback_data = nullptr;
|
| 239 |
|
| 240 |
+
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
|
| 241 |
+
|
| 242 |
+
// buffer types used for the compute buffer of each backend
|
| 243 |
+
std::vector<ggml_backend_t> backend_ptrs;
|
| 244 |
+
std::vector<ggml_backend_buffer_type_t> backend_buft;
|
|
| 245 |
|
| 246 |
+
// memory buffers used to evaluate the model
|
| 247 |
+
std::vector<uint8_t> buf_compute_meta;
|
| 248 |
|
| 249 |
+
// host buffer for the model output (logits and embeddings)
|
| 250 |
+
ggml_backend_buffer_ptr buf_output;
|
| 251 |
|
| 252 |
+
bool has_evaluated_once = false;
|
| 253 |
|
| 254 |
+
// perf
|
| 255 |
+
mutable int64_t t_start_us = 0;
|
| 256 |
+
mutable int64_t t_load_us = 0;
|
| 257 |
+
mutable int64_t t_p_eval_us = 0;
|
| 258 |
+
mutable int64_t t_eval_us = 0;
|
| 259 |
|
| 260 |
+
mutable int64_t t_compute_start_us = 0;
|
| 261 |
+
mutable int64_t n_queued_tokens = 0;
|
| 262 |
|
| 263 |
+
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
|
| 264 |
+
mutable int32_t n_eval = 0; // number of eval calls
|
| 265 |
+
};
|
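The refactored llama_context above keeps an output_ids table ("map batch token positions to ids of the logits and embd buffers"). A toy, self-contained sketch of that indirection (the sizes, values and variable layout are invented for illustration; this is not the actual implementation):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // 5 batch positions, but only positions 1 and 4 requested an output
        const int n_vocab = 4;
        std::vector<int32_t> output_ids = { -1, 0, -1, -1, 1 }; // batch pos -> row in the logits buffer, -1 = no output
        std::vector<float> logits = {
            0.1f, 0.2f, 0.3f, 0.4f,   // row 0 -> batch position 1
            0.5f, 0.6f, 0.7f, 0.8f,   // row 1 -> batch position 4
        };

        const int i = 4; // fetch logits for batch position 4
        if (output_ids[i] >= 0) {
            const float * row = logits.data() + (size_t) output_ids[i]*n_vocab;
            printf("first logit for position %d: %f\n", i, row[0]);
        }
        return 0;
    }
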
examples/talk-llama/llama-cparams.h
CHANGED
|
@@ -29,6 +29,7 @@ struct llama_cparams {
|
|
| 29 |
bool offload_kqv;
|
| 30 |
bool flash_attn;
|
| 31 |
bool no_perf;
|
|
|
|
| 32 |
|
| 33 |
enum llama_pooling_type pooling_type;
|
| 34 |
|
|
|
|
| 29 |
bool offload_kqv;
|
| 30 |
bool flash_attn;
|
| 31 |
bool no_perf;
|
| 32 |
+
bool warmup;
|
| 33 |
|
| 34 |
enum llama_pooling_type pooling_type;
|
| 35 |
|
examples/talk-llama/llama-grammar.cpp
CHANGED
|
@@ -345,194 +345,194 @@ const char * llama_grammar_parser::parse_sequence(
|
|
| 345 |
size_t last_sym_start = rule.size();
|
| 346 |
const char * pos = src;
|
| 347 |
|
| 348 |
-
|
| 349 |
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
}
|
| 379 |
}
|
|
|
|
| 380 |
|
| 381 |
-
|
| 382 |
-
|
| 383 |
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
}
|
| 391 |
-
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
| 392 |
-
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
|
| 393 |
-
add_rule( rec_rule_id, rec_rule);
|
| 394 |
-
last_rec_rule_id = rec_rule_id;
|
| 395 |
}
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
}
|
| 409 |
-
auto char_pair = parse_char(pos);
|
| 410 |
-
pos = char_pair.second;
|
| 411 |
-
rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
|
| 412 |
}
|
| 413 |
-
|
| 414 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
pos++;
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
|
|
|
|
|
|
| 420 |
}
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
throw std::runtime_error("unexpected end of input");
|
| 425 |
}
|
| 426 |
-
auto
|
| 427 |
-
pos
|
| 428 |
-
|
| 429 |
-
? LLAMA_GRETYPE_CHAR_ALT
|
| 430 |
-
: start_type;
|
| 431 |
-
|
| 432 |
-
rule.push_back({type, char_pair.first});
|
| 433 |
-
if (pos[0] == '-' && pos[1] != ']') {
|
| 434 |
-
if (!pos[1]) {
|
| 435 |
-
throw std::runtime_error("unexpected end of input");
|
| 436 |
-
}
|
| 437 |
-
auto endchar_pair = parse_char(pos + 1);
|
| 438 |
-
pos = endchar_pair.second;
|
| 439 |
-
rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
|
| 440 |
-
}
|
| 441 |
-
}
|
| 442 |
-
pos = parse_space(pos + 1, is_nested);
|
| 443 |
-
} else if (is_word_char(*pos)) { // rule reference
|
| 444 |
-
const char * name_end = parse_name(pos);
|
| 445 |
-
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
|
| 446 |
-
pos = parse_space(name_end, is_nested);
|
| 447 |
-
last_sym_start = rule.size();
|
| 448 |
-
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
|
| 449 |
-
} else if (*pos == '(') { // grouping
|
| 450 |
-
// parse nested alternates into synthesized rule
|
| 451 |
-
pos = parse_space(pos + 1, true);
|
| 452 |
-
uint32_t sub_rule_id = generate_symbol_id(rule_name);
|
| 453 |
-
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
|
| 454 |
-
last_sym_start = rule.size();
|
| 455 |
-
// output reference to synthesized rule
|
| 456 |
-
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
|
| 457 |
-
if (*pos != ')') {
|
| 458 |
-
throw std::runtime_error(std::string("expecting ')' at ") + pos);
|
| 459 |
}
|
| 460 |
pos = parse_space(pos + 1, is_nested);
|
| 461 |
-
} else if (*pos == '
|
| 462 |
-
last_sym_start = rule.size();
|
| 463 |
-
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
|
| 464 |
-
pos = parse_space(pos + 1, is_nested);
|
| 465 |
-
} else if (*pos == '*') {
|
| 466 |
-
pos = parse_space(pos + 1, is_nested);
|
| 467 |
-
handle_repetitions(0, -1);
|
| 468 |
-
} else if (*pos == '+') {
|
| 469 |
-
pos = parse_space(pos + 1, is_nested);
|
| 470 |
-
handle_repetitions(1, -1);
|
| 471 |
-
} else if (*pos == '?') {
|
| 472 |
-
pos = parse_space(pos + 1, is_nested);
|
| 473 |
-
handle_repetitions(0, 1);
|
| 474 |
-
} else if (*pos == '{') {
|
| 475 |
pos = parse_space(pos + 1, is_nested);
|
| 476 |
|
| 477 |
-
if (
|
| 478 |
-
|
|
|
|
|
|
|
| 479 |
}
|
| 480 |
-
const char * int_end = parse_int(pos);
|
| 481 |
-
int min_times = std::stoul(std::string(pos, int_end - pos));
|
| 482 |
-
pos = parse_space(int_end, is_nested);
|
| 483 |
-
|
| 484 |
-
int max_times = -1;
|
| 485 |
-
|
| 486 |
-
if (*pos == '}') {
|
| 487 |
-
max_times = min_times;
|
| 488 |
-
pos = parse_space(pos + 1, is_nested);
|
| 489 |
-
} else if (*pos == ',') {
|
| 490 |
-
pos = parse_space(pos + 1, is_nested);
|
| 491 |
-
|
| 492 |
-
if (is_digit_char(*pos)) {
|
| 493 |
-
const char * int_end = parse_int(pos);
|
| 494 |
-
max_times = std::stoul(std::string(pos, int_end - pos));
|
| 495 |
-
pos = parse_space(int_end, is_nested);
|
| 496 |
-
}
|
| 497 |
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
}
|
| 501 |
-
pos = parse_space(pos + 1, is_nested);
|
| 502 |
-
} else {
|
| 503 |
-
throw std::runtime_error(std::string("expecting ',' at ") + pos);
|
| 504 |
}
|
| 505 |
-
|
| 506 |
} else {
|
| 507 |
-
|
| 508 |
}
|
|
|
|
|
|
|
|
|
|
| 509 |
}
|
| 510 |
-
return pos;
|
| 511 |
}
|
|
|
|
|
|
|
| 512 |
|
| 513 |
const char * llama_grammar_parser::parse_rule(const char * src) {
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
|
| 525 |
-
|
| 526 |
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
}
|
| 534 |
-
return parse_space(pos, true);
|
| 535 |
}
|
|
|
|
|
|
|
| 536 |
|
| 537 |
bool llama_grammar_parser::parse(const char * src) {
|
| 538 |
try {
|
|
@@ -969,7 +969,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
|
| 969 |
/* .awaiting_trigger = */ false,
|
| 970 |
/* .trigger_buffer = */ "",
|
| 971 |
/* .trigger_tokens = */ {},
|
| 972 |
-
/* .
|
| 973 |
};
|
| 974 |
}
|
| 975 |
|
|
@@ -978,19 +978,15 @@ struct llama_grammar * llama_grammar_init_impl(
|
|
| 978 |
const char * grammar_str,
|
| 979 |
const char * grammar_root,
|
| 980 |
bool lazy,
|
| 981 |
-
const char **
|
| 982 |
-
size_t
|
| 983 |
const llama_token * trigger_tokens,
|
| 984 |
size_t num_trigger_tokens) {
|
| 985 |
llama_grammar_parser parser;
|
| 986 |
|
| 987 |
// if there is a grammar, parse it
|
| 988 |
-
|
| 989 |
-
|
| 990 |
-
}
|
| 991 |
-
|
| 992 |
-
// will be empty (default) if there are parse errors
|
| 993 |
-
if (parser.rules.empty()) {
|
| 994 |
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
|
| 995 |
return nullptr;
|
| 996 |
}
|
|
@@ -1054,14 +1050,16 @@ struct llama_grammar * llama_grammar_init_impl(
|
|
| 1054 |
} while (true);
|
| 1055 |
|
| 1056 |
std::vector<llama_token> vec_trigger_tokens;
|
| 1057 |
-
std::vector<
|
| 1058 |
for (size_t i = 0; i < num_trigger_tokens; i++) {
|
| 1059 |
GGML_ASSERT(trigger_tokens != nullptr);
|
| 1060 |
vec_trigger_tokens.push_back(trigger_tokens[i]);
|
| 1061 |
}
|
| 1062 |
-
for (size_t i = 0; i <
|
| 1063 |
-
GGML_ASSERT(
|
| 1064 |
-
|
|
|
|
|
|
|
| 1065 |
}
|
| 1066 |
|
| 1067 |
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
|
@@ -1076,7 +1074,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
|
| 1076 |
/* .awaiting_trigger = */ lazy,
|
| 1077 |
/* .trigger_buffer = */ "",
|
| 1078 |
std::move(vec_trigger_tokens),
|
| 1079 |
-
std::move(
|
| 1080 |
};
|
| 1081 |
}
|
| 1082 |
|
|
@@ -1089,7 +1087,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
|
|
| 1089 |
}
|
| 1090 |
|
| 1091 |
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
|
| 1092 |
-
|
| 1093 |
grammar.vocab,
|
| 1094 |
grammar.rules,
|
| 1095 |
grammar.stacks,
|
|
@@ -1098,7 +1096,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
|
|
| 1098 |
grammar.awaiting_trigger,
|
| 1099 |
grammar.trigger_buffer,
|
| 1100 |
grammar.trigger_tokens,
|
| 1101 |
-
grammar.
|
| 1102 |
};
|
| 1103 |
|
| 1104 |
// redirect elements in stacks to point to new rules
|
|
@@ -1173,20 +1171,22 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
|
| 1173 |
LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
|
| 1174 |
return;
|
| 1175 |
} else {
|
| 1176 |
-
// TODO: consider a smarter incremental substring search algorithm (store last position to search from).
|
| 1177 |
grammar.trigger_buffer += piece;
|
| 1178 |
-
|
| 1179 |
-
|
| 1180 |
-
|
|
|
|
| 1181 |
grammar.awaiting_trigger = false;
|
| 1182 |
-
|
|
|
|
|
|
|
| 1183 |
grammar.trigger_buffer.clear();
|
| 1184 |
llama_grammar_accept_str(grammar, constrained_str);
|
| 1185 |
-
LLAMA_LOG_DEBUG("Grammar triggered on
|
| 1186 |
return;
|
| 1187 |
}
|
| 1188 |
}
|
| 1189 |
-
LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)
|
| 1190 |
return;
|
| 1191 |
}
|
| 1192 |
}
|
|
|
|
| 345 |
size_t last_sym_start = rule.size();
|
| 346 |
const char * pos = src;
|
| 347 |
|
| 348 |
+
auto handle_repetitions = [&](int min_times, int max_times) {
|
| 349 |
|
| 350 |
+
if (last_sym_start == rule.size()) {
|
| 351 |
+
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
|
| 352 |
+
}
|
| 353 |
|
| 354 |
+
// apply transformation to previous symbol (last_sym_start to end) according to
|
| 355 |
+
// the following rewrite rules:
|
| 356 |
+
// S{m,n} --> S S S (m times) S'(n-m)
|
| 357 |
+
// S'(x) ::= S S'(x-1) |
|
| 358 |
+
// (... n-m definitions of these S' rules ...)
|
| 359 |
+
// S'(1) ::= S |
|
| 360 |
+
// S{m,} --> S S S (m times) S'
|
| 361 |
+
// S' ::= S S' |
|
| 362 |
+
// S* --> S{0,}
|
| 363 |
+
// --> S' ::= S S' |
|
| 364 |
+
// S+ --> S{1,}
|
| 365 |
+
// --> S S'
|
| 366 |
+
// S' ::= S S' |
|
| 367 |
+
// S? --> S{0,1}
|
| 368 |
+
// --> S'
|
| 369 |
+
// S' ::= S |
|
| 370 |
+
|
| 371 |
+
llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
|
| 372 |
+
if (min_times == 0) {
|
| 373 |
+
rule.resize(last_sym_start);
|
| 374 |
+
} else {
|
| 375 |
+
// Repeat the previous elements (min_times - 1) times
|
| 376 |
+
for (int i = 1; i < min_times; i++) {
|
| 377 |
+
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
|
|
|
|
| 378 |
}
|
| 379 |
+
}
|
| 380 |
|
| 381 |
+
uint32_t last_rec_rule_id = 0;
|
| 382 |
+
auto n_opt = max_times < 0 ? 1 : max_times - min_times;
|
| 383 |
|
| 384 |
+
llama_grammar_rule rec_rule(prev_rule);
|
| 385 |
+
for (int i = 0; i < n_opt; i++) {
|
| 386 |
+
rec_rule.resize(prev_rule.size());
|
| 387 |
+
uint32_t rec_rule_id = generate_symbol_id( rule_name);
|
| 388 |
+
if (i > 0 || max_times < 0) {
|
| 389 |
+
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
}
|
| 391 |
+
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
| 392 |
+
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
|
| 393 |
+
add_rule( rec_rule_id, rec_rule);
|
| 394 |
+
last_rec_rule_id = rec_rule_id;
|
| 395 |
+
}
|
| 396 |
+
if (n_opt > 0) {
|
| 397 |
+
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
|
| 398 |
+
}
|
| 399 |
+
};
|
| 400 |
|
| 401 |
+
while (*pos) {
|
| 402 |
+
if (*pos == '"') { // literal string
|
| 403 |
+
pos++;
|
| 404 |
+
last_sym_start = rule.size();
|
| 405 |
+
while (*pos != '"') {
|
| 406 |
+
if (!*pos) {
|
| 407 |
+
throw std::runtime_error("unexpected end of input");
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
}
|
| 409 |
+
auto char_pair = parse_char(pos);
|
| 410 |
+
pos = char_pair.second;
|
| 411 |
+
rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
|
| 412 |
+
}
|
| 413 |
+
pos = parse_space(pos + 1, is_nested);
|
| 414 |
+
} else if (*pos == '[') { // char range(s)
|
| 415 |
+
pos++;
|
| 416 |
+
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
|
| 417 |
+
if (*pos == '^') {
|
| 418 |
pos++;
|
| 419 |
+
start_type = LLAMA_GRETYPE_CHAR_NOT;
|
| 420 |
+
}
|
| 421 |
+
last_sym_start = rule.size();
|
| 422 |
+
while (*pos != ']') {
|
| 423 |
+
if (!*pos) {
|
| 424 |
+
throw std::runtime_error("unexpected end of input");
|
| 425 |
}
|
| 426 |
+
auto char_pair = parse_char(pos);
|
| 427 |
+
pos = char_pair.second;
|
| 428 |
+
enum llama_gretype type = last_sym_start < rule.size()
|
| 429 |
+
? LLAMA_GRETYPE_CHAR_ALT
|
| 430 |
+
: start_type;
|
| 431 |
+
|
| 432 |
+
rule.push_back({type, char_pair.first});
|
| 433 |
+
if (pos[0] == '-' && pos[1] != ']') {
|
| 434 |
+
if (!pos[1]) {
|
| 435 |
throw std::runtime_error("unexpected end of input");
|
| 436 |
}
|
| 437 |
+
auto endchar_pair = parse_char(pos + 1);
|
| 438 |
+
pos = endchar_pair.second;
|
| 439 |
+
rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
|
| 440 |
}
|
| 441 |
+
}
|
| 442 |
+
pos = parse_space(pos + 1, is_nested);
|
| 443 |
+
} else if (is_word_char(*pos)) { // rule reference
|
| 444 |
+
const char * name_end = parse_name(pos);
|
| 445 |
+
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
|
| 446 |
+
pos = parse_space(name_end, is_nested);
|
| 447 |
+
last_sym_start = rule.size();
|
| 448 |
+
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
|
| 449 |
+
} else if (*pos == '(') { // grouping
|
| 450 |
+
// parse nested alternates into synthesized rule
|
| 451 |
+
pos = parse_space(pos + 1, true);
|
| 452 |
+
uint32_t sub_rule_id = generate_symbol_id(rule_name);
|
| 453 |
+
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
|
| 454 |
+
last_sym_start = rule.size();
|
| 455 |
+
// output reference to synthesized rule
|
| 456 |
+
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
|
| 457 |
+
if (*pos != ')') {
|
| 458 |
+
throw std::runtime_error(std::string("expecting ')' at ") + pos);
|
| 459 |
+
}
|
| 460 |
+
pos = parse_space(pos + 1, is_nested);
|
| 461 |
+
} else if (*pos == '.') { // any char
|
| 462 |
+
last_sym_start = rule.size();
|
| 463 |
+
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
|
| 464 |
+
pos = parse_space(pos + 1, is_nested);
|
| 465 |
+
} else if (*pos == '*') {
|
| 466 |
+
pos = parse_space(pos + 1, is_nested);
|
| 467 |
+
handle_repetitions(0, -1);
|
| 468 |
+
} else if (*pos == '+') {
|
| 469 |
+
pos = parse_space(pos + 1, is_nested);
|
| 470 |
+
handle_repetitions(1, -1);
|
| 471 |
+
} else if (*pos == '?') {
|
| 472 |
+
pos = parse_space(pos + 1, is_nested);
|
| 473 |
+
handle_repetitions(0, 1);
|
| 474 |
+
} else if (*pos == '{') {
|
| 475 |
+
pos = parse_space(pos + 1, is_nested);
|
| 476 |
+
|
| 477 |
+
if (!is_digit_char(*pos)) {
|
| 478 |
+
throw std::runtime_error(std::string("expecting an int at ") + pos);
|
| 479 |
+
}
|
| 480 |
+
const char * int_end = parse_int(pos);
|
| 481 |
+
int min_times = std::stoul(std::string(pos, int_end - pos));
|
| 482 |
+
pos = parse_space(int_end, is_nested);
|
| 483 |
+
|
| 484 |
+
int max_times = -1;
|
| 485 |
+
|
| 486 |
+
if (*pos == '}') {
|
| 487 |
+
max_times = min_times;
|
| 488 |
pos = parse_space(pos + 1, is_nested);
|
| 489 |
+
} else if (*pos == ',') {
|
| 490 |
pos = parse_space(pos + 1, is_nested);
|
| 491 |
|
| 492 |
+
if (is_digit_char(*pos)) {
|
| 493 |
+
const char * int_end = parse_int(pos);
|
| 494 |
+
max_times = std::stoul(std::string(pos, int_end - pos));
|
| 495 |
+
pos = parse_space(int_end, is_nested);
|
| 496 |
}
|
| 497 |
|
| 498 |
+
if (*pos != '}') {
|
| 499 |
+
throw std::runtime_error(std::string("expecting '}' at ") + pos);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
}
|
| 501 |
+
pos = parse_space(pos + 1, is_nested);
|
| 502 |
} else {
|
| 503 |
+
throw std::runtime_error(std::string("expecting ',' at ") + pos);
|
| 504 |
}
|
| 505 |
+
handle_repetitions(min_times, max_times);
|
| 506 |
+
} else {
|
| 507 |
+
break;
|
| 508 |
}
|
|
|
|
| 509 |
}
|
| 510 |
+
return pos;
|
| 511 |
+
}
|
| 512 |
|
| 513 |
const char * llama_grammar_parser::parse_rule(const char * src) {
|
| 514 |
+
const char * name_end = parse_name(src);
|
| 515 |
+
const char * pos = parse_space(name_end, false);
|
| 516 |
+
size_t name_len = name_end - src;
|
| 517 |
+
uint32_t rule_id = get_symbol_id(src, name_len);
|
| 518 |
+
const std::string name(src, name_len);
|
| 519 |
+
|
| 520 |
+
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
|
| 521 |
+
throw std::runtime_error(std::string("expecting ::= at ") + pos);
|
| 522 |
+
}
|
| 523 |
+
pos = parse_space(pos + 3, true);
|
| 524 |
|
| 525 |
+
pos = parse_alternates(pos, name, rule_id, false);
|
| 526 |
|
| 527 |
+
if (*pos == '\r') {
|
| 528 |
+
pos += pos[1] == '\n' ? 2 : 1;
|
| 529 |
+
} else if (*pos == '\n') {
|
| 530 |
+
pos++;
|
| 531 |
+
} else if (*pos) {
|
| 532 |
+
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
|
|
|
|
|
|
|
| 533 |
}
|
| 534 |
+
return parse_space(pos, true);
|
| 535 |
+
}
|
| 536 |
|
| 537 |
bool llama_grammar_parser::parse(const char * src) {
|
| 538 |
try {
|
|
|
|
| 969 |
/* .awaiting_trigger = */ false,
|
| 970 |
/* .trigger_buffer = */ "",
|
| 971 |
/* .trigger_tokens = */ {},
|
| 972 |
+
/* .trigger_patterns = */ {},
|
| 973 |
};
|
| 974 |
}
|
| 975 |
|
|
|
|
| 978 |
const char * grammar_str,
|
| 979 |
const char * grammar_root,
|
| 980 |
bool lazy,
|
| 981 |
+
const char ** trigger_patterns,
|
| 982 |
+
size_t num_trigger_patterns,
|
| 983 |
const llama_token * trigger_tokens,
|
| 984 |
size_t num_trigger_tokens) {
|
| 985 |
llama_grammar_parser parser;
|
| 986 |
|
| 987 |
// if there is a grammar, parse it
|
| 988 |
+
// rules will be empty (default) if there are parse errors
|
| 989 |
+
if (!parser.parse(grammar_str) || parser.rules.empty()) {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 990 |
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
|
| 991 |
return nullptr;
|
| 992 |
}
|
|
|
|
| 1050 |
} while (true);
|
| 1051 |
|
| 1052 |
std::vector<llama_token> vec_trigger_tokens;
|
| 1053 |
+
std::vector<llama_grammar_trigger_pattern> vec_trigger_patterns;
|
| 1054 |
for (size_t i = 0; i < num_trigger_tokens; i++) {
|
| 1055 |
GGML_ASSERT(trigger_tokens != nullptr);
|
| 1056 |
vec_trigger_tokens.push_back(trigger_tokens[i]);
|
| 1057 |
}
|
| 1058 |
+
for (size_t i = 0; i < num_trigger_patterns; i++) {
|
| 1059 |
+
GGML_ASSERT(trigger_patterns != nullptr);
|
| 1060 |
+
auto & trigger = vec_trigger_patterns.emplace_back();
|
| 1061 |
+
trigger.pattern = trigger_patterns[i];
|
| 1062 |
+
trigger.regex = std::regex(trigger.pattern);
|
| 1063 |
}
|
| 1064 |
|
| 1065 |
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
|
|
|
| 1074 |
/* .awaiting_trigger = */ lazy,
|
| 1075 |
/* .trigger_buffer = */ "",
|
| 1076 |
std::move(vec_trigger_tokens),
|
| 1077 |
+
std::move(vec_trigger_patterns),
|
| 1078 |
};
|
| 1079 |
}
|
| 1080 |
|
|
|
|
| 1087 |
}
|
| 1088 |
|
| 1089 |
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
|
| 1090 |
+
auto * result = new llama_grammar {
|
| 1091 |
grammar.vocab,
|
| 1092 |
grammar.rules,
|
| 1093 |
grammar.stacks,
|
|
|
|
| 1096 |
grammar.awaiting_trigger,
|
| 1097 |
grammar.trigger_buffer,
|
| 1098 |
grammar.trigger_tokens,
|
| 1099 |
+
grammar.trigger_patterns,
|
| 1100 |
};
|
| 1101 |
|
| 1102 |
// redirect elements in stacks to point to new rules
|
|
|
|
| 1171 |
LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
|
| 1172 |
return;
|
| 1173 |
} else {
|
|
|
|
| 1174 |
grammar.trigger_buffer += piece;
|
| 1175 |
+
|
| 1176 |
+
std::smatch match;
|
| 1177 |
+
for (const auto & trigger_pattern : grammar.trigger_patterns) {
|
| 1178 |
+
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
|
| 1179 |
grammar.awaiting_trigger = false;
|
| 1180 |
+
// get from the first match to the end of the string
|
| 1181 |
+
auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
|
| 1182 |
+
// std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
|
| 1183 |
grammar.trigger_buffer.clear();
|
| 1184 |
llama_grammar_accept_str(grammar, constrained_str);
|
| 1185 |
+
LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
|
| 1186 |
return;
|
| 1187 |
}
|
| 1188 |
}
|
| 1189 |
+
LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str());
|
| 1190 |
return;
|
| 1191 |
}
|
| 1192 |
}
|
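Worked through concretely, the repetition rewrite documented in the restored handle_repetitions comment above behaves as follows (R0/R1 stand for the synthesized rules that generate_symbol_id would create; this is a hand-worked example, not parser output):

    S{2,4}   becomes   S S R1    with   R1 ::= S R0 |     R0 ::= S |
    S+       becomes   S R       with   R  ::= S R |
    S*       becomes   R         with   R  ::= S R |
    S?       becomes   R         with   R  ::= S |
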
examples/talk-llama/llama-grammar.h
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
#include "llama.h"
|
| 4 |
|
| 5 |
#include <map>
|
|
|
|
| 6 |
#include <string>
|
| 7 |
#include <vector>
|
| 8 |
|
|
@@ -105,6 +106,11 @@ struct llama_grammar_parser {
|
|
| 105 |
void print(FILE * file);
|
| 106 |
};
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
struct llama_grammar {
|
| 109 |
// note: allow null vocab for testing (not great)
|
| 110 |
const llama_vocab * vocab;
|
|
@@ -116,13 +122,16 @@ struct llama_grammar {
|
|
| 116 |
llama_partial_utf8 partial_utf8;
|
| 117 |
|
| 118 |
// lazy grammars wait for trigger words or tokens before constraining the sampling.
|
| 119 |
-
// we still
|
| 120 |
// (useful e.g. for tool_choice=required)
|
| 121 |
bool lazy = false;
|
| 122 |
bool awaiting_trigger = false; // Initialized to true for lazy grammars only
|
| 123 |
std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
|
| 124 |
std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
|
| 125 |
-
std::vector<
|
|
|
|
|
|
|
|
|
|
| 126 |
};
|
| 127 |
|
| 128 |
//
|
|
@@ -141,8 +150,8 @@ struct llama_grammar * llama_grammar_init_impl(
|
|
| 141 |
const char * grammar_str,
|
| 142 |
const char * grammar_root,
|
| 143 |
bool lazy,
|
| 144 |
-
const char **
|
| 145 |
-
size_t
|
| 146 |
const llama_token * trigger_tokens,
|
| 147 |
size_t num_trigger_tokens);
|
| 148 |
|
|
|
|
| 3 |
#include "llama.h"
|
| 4 |
|
| 5 |
#include <map>
|
| 6 |
+
#include <regex>
|
| 7 |
#include <string>
|
| 8 |
#include <vector>
|
| 9 |
|
|
|
|
| 106 |
void print(FILE * file);
|
| 107 |
};
|
| 108 |
|
| 109 |
+
struct llama_grammar_trigger_pattern {
|
| 110 |
+
std::string pattern;
|
| 111 |
+
std::regex regex;
|
| 112 |
+
};
|
| 113 |
+
|
| 114 |
struct llama_grammar {
|
| 115 |
// note: allow null vocab for testing (not great)
|
| 116 |
const llama_vocab * vocab;
|
|
|
|
| 122 |
llama_partial_utf8 partial_utf8;
|
| 123 |
|
| 124 |
// lazy grammars wait for trigger words or tokens before constraining the sampling.
|
| 125 |
+
// we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
|
| 126 |
// (useful e.g. for tool_choice=required)
|
| 127 |
bool lazy = false;
|
| 128 |
bool awaiting_trigger = false; // Initialized to true for lazy grammars only
|
| 129 |
std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
|
| 130 |
std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
|
| 131 |
+
std::vector<llama_grammar_trigger_pattern>
|
| 132 |
+
trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated
|
| 133 |
+
// string, and the grammar will be given the string from the first match group onwards.
|
| 134 |
+
|
| 135 |
};
|
| 136 |
|
| 137 |
//
|
|
|
|
| 150 |
const char * grammar_str,
|
| 151 |
const char * grammar_root,
|
| 152 |
bool lazy,
|
| 153 |
+
const char ** trigger_patterns,
|
| 154 |
+
size_t num_trigger_patterns,
|
| 155 |
const llama_token * trigger_tokens,
|
| 156 |
size_t num_trigger_tokens);
|
| 157 |
|
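The new trigger_patterns field above drives the lazy-grammar path in llama_grammar_accept_impl: the buffered output must fully match the regex, and the grammar is then fed the text starting at the first capture group. A minimal standalone sketch of those mechanics (the pattern and buffered text are hypothetical examples, not values used by llama.cpp):

    #include <cstdio>
    #include <regex>
    #include <string>

    int main() {
        // hypothetical trigger: full-match the whole buffer, capture from "<tool_call>" onwards
        const std::regex re("[\\s\\S]*?(<tool_call>[\\s\\S]*)");

        std::string trigger_buffer = "Let me call a tool.\n<tool_call>{\"name\":\"f\"}";

        std::smatch match;
        if (std::regex_match(trigger_buffer, match, re)) {
            // same idea as llama_grammar_accept_impl: hand the grammar everything
            // from the first capture group to the end of the buffered string
            std::string constrained_str = trigger_buffer.substr(match.position(1));
            printf("grammar sees: %s\n", constrained_str.c_str());
        }
        return 0;
    }
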
examples/talk-llama/llama-graph.cpp
ADDED
|
@@ -0,0 +1,1706 @@
|
| 1 |
+
#include "llama-graph.h"
|
| 2 |
+
|
| 3 |
+
#include "llama-impl.h"
|
| 4 |
+
#include "llama-batch.h"
|
| 5 |
+
#include "llama-cparams.h"
|
| 6 |
+
#include "llama-kv-cache.h"
|
| 7 |
+
|
| 8 |
+
#include <cassert>
|
| 9 |
+
#include <cmath>
|
| 10 |
+
#include <cstring>
|
| 11 |
+
|
| 12 |
+
static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
|
| 13 |
+
// TODO move to hparams if a T5 variant appears that uses a different value
|
| 14 |
+
const int64_t max_distance = 128;
|
| 15 |
+
|
| 16 |
+
if (bidirectional) {
|
| 17 |
+
n_buckets >>= 1;
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
const int64_t max_exact = n_buckets >> 1;
|
| 21 |
+
|
| 22 |
+
int32_t relative_position = x - y;
|
| 23 |
+
int32_t relative_bucket = 0;
|
| 24 |
+
|
| 25 |
+
if (bidirectional) {
|
| 26 |
+
relative_bucket += (relative_position > 0) * n_buckets;
|
| 27 |
+
relative_position = abs(relative_position);
|
| 28 |
+
} else {
|
| 29 |
+
relative_position = -std::min<int32_t>(relative_position, 0);
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
|
| 33 |
+
relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
|
| 34 |
+
relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
|
| 35 |
+
|
| 36 |
+
return relative_bucket;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
| 40 |
+
if (ubatch->token) {
|
| 41 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 42 |
+
|
| 43 |
+
ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
if (ubatch->embd) {
|
| 47 |
+
const int64_t n_embd = embd->ne[0];
|
| 48 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 49 |
+
|
| 50 |
+
ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
|
| 51 |
+
}
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
|
| 55 |
+
if (ubatch->pos && pos) {
|
| 56 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 57 |
+
|
| 58 |
+
ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_token*ggml_element_size(pos));
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
|
| 63 |
+
if (ubatch->pos && attn_scale) {
|
| 64 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 65 |
+
|
| 66 |
+
std::vector<float> attn_scale_data(n_tokens, 0.0f);
|
| 67 |
+
for (int i = 0; i < n_tokens; ++i) {
|
| 68 |
+
const float pos = ubatch->pos[i];
|
| 69 |
+
attn_scale_data[i] = std::log(
|
| 70 |
+
std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
|
| 71 |
+
) * f_attn_temp_scale + 1.0;
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(attn_scale));
|
| 75 |
+
}
|
| 76 |
+
}
|
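Note: the per-token scale filled in above corresponds to

    s_i = 1 + f_{attn\_temp\_scale} \cdot \log\!\left(\left\lfloor \frac{p_i + 1}{n_{attn\_temp\_floor\_scale}} \right\rfloor + 1\right)

where p_i is the position of token i, so the attention temperature grows logarithmically with position, in floor-sized steps.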
| 77 |
+
|
| 78 |
+
void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
|
| 79 |
+
if (pos_bucket) {
|
| 80 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 81 |
+
|
| 82 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
|
| 83 |
+
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
|
| 84 |
+
|
| 85 |
+
int32_t * data = (int32_t *) pos_bucket->data;
|
| 86 |
+
|
| 87 |
+
for (int h = 0; h < 1; ++h) {
|
| 88 |
+
for (int j = 0; j < n_tokens; ++j) {
|
| 89 |
+
for (int i = 0; i < n_tokens; ++i) {
|
| 90 |
+
data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
|
| 91 |
+
}
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
}
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
|
| 98 |
+
if (pos_bucket) {
|
| 99 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 100 |
+
|
| 101 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
|
| 102 |
+
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
|
| 103 |
+
|
| 104 |
+
int32_t * data = (int32_t *) pos_bucket->data;
|
| 105 |
+
|
| 106 |
+
const int64_t n_kv = kv_self->n;
|
| 107 |
+
|
| 108 |
+
for (int h = 0; h < 1; ++h) {
|
| 109 |
+
for (int j = 0; j < n_tokens; ++j) {
|
| 110 |
+
for (int i = 0; i < n_kv; ++i) {
|
| 111 |
+
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
}
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
|
| 119 |
+
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
| 120 |
+
//GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
|
| 121 |
+
|
| 122 |
+
if (!out_ids) {
|
| 123 |
+
LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
|
| 124 |
+
} else {
|
| 125 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 126 |
+
|
| 127 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
|
| 128 |
+
int32_t * data = (int32_t *) out_ids->data;
|
| 129 |
+
|
| 130 |
+
if (n_outputs == n_tokens) {
|
| 131 |
+
for (int i = 0; i < n_tokens; ++i) {
|
| 132 |
+
data[i] = i;
|
| 133 |
+
}
|
| 134 |
+
} else if (ubatch->output) {
|
| 135 |
+
int32_t n_outputs_found = 0;
|
| 136 |
+
for (int i = 0; i < n_tokens; ++i) {
|
| 137 |
+
if (ubatch->output[i]) {
|
| 138 |
+
data[n_outputs_found++] = i;
|
| 139 |
+
}
|
| 140 |
+
}
|
| 141 |
+
// the graph needs to have been passed the correct number of outputs
|
| 142 |
+
GGML_ASSERT(n_outputs_found == n_outputs);
|
| 143 |
+
} else if (n_outputs == 1) {
|
| 144 |
+
// only keep last output
|
| 145 |
+
data[0] = n_tokens - 1;
|
| 146 |
+
} else {
|
| 147 |
+
GGML_ASSERT(n_outputs == 0);
|
| 148 |
+
}
|
| 149 |
+
}
|
| 150 |
+
}
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
|
| 154 |
+
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
|
| 155 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 156 |
+
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
| 157 |
+
const int64_t n_seqs = ubatch->n_seqs;
|
| 158 |
+
|
| 159 |
+
GGML_ASSERT(mean);
|
| 160 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
|
| 161 |
+
|
| 162 |
+
float * data = (float *) mean->data;
|
| 163 |
+
memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean));
|
| 164 |
+
|
| 165 |
+
std::vector<uint64_t> sum(n_tokens, 0);
|
| 166 |
+
|
| 167 |
+
for (int s = 0; s < n_seqs; ++s) {
|
| 168 |
+
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
| 169 |
+
|
| 170 |
+
// TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
|
| 171 |
+
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
|
| 172 |
+
|
| 173 |
+
sum[seq_id] += ubatch->n_seq_tokens;
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
std::vector<float> div(n_tokens, 0.0f);
|
| 177 |
+
for (int i = 0; i < n_tokens; ++i) {
|
| 178 |
+
const uint64_t s = sum[i];
|
| 179 |
+
if (s > 0) {
|
| 180 |
+
div[i] = 1.0f/float(s);
|
| 181 |
+
}
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
for (int s = 0; s < n_seqs; ++s) {
|
| 185 |
+
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
| 186 |
+
|
| 187 |
+
for (int i = 0; i < n_seq_tokens; ++i) {
|
| 188 |
+
data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
|
| 189 |
+
}
|
| 190 |
+
}
|
| 191 |
+
}
|
| 192 |
+
}
|
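Note: the resulting inp_mean tensor is an n_tokens x n_tokens averaging matrix: the row for sequence s holds 1/len(s) in the columns of that sequence's tokens and 0 elsewhere, so a matrix product with the token embeddings yields one mean-pooled embedding per sequence. A small hand-worked example (assuming two sequences of 2 tokens each, n_tokens = 4, n_seq_tokens = 2):

    row 0 (seq 0): 0.5  0.5  0    0
    row 1 (seq 1): 0    0    0.5  0.5
    rows 2-3:      0    0    0    0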
| 193 |
+
|
| 194 |
+
void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
| 195 |
+
if (cparams.embeddings && (
|
| 196 |
+
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
|
| 197 |
+
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
|
| 198 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 199 |
+
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
| 200 |
+
const int64_t n_seqs = ubatch->n_seqs;
|
| 201 |
+
|
| 202 |
+
GGML_ASSERT(cls);
|
| 203 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
|
| 204 |
+
|
| 205 |
+
uint32_t * data = (uint32_t *) cls->data;
|
| 206 |
+
memset(cls->data, 0, n_tokens * ggml_element_size(cls));
|
| 207 |
+
|
| 208 |
+
for (int s = 0; s < n_seqs; ++s) {
|
| 209 |
+
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
| 210 |
+
|
| 211 |
+
// TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
|
| 212 |
+
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
|
| 213 |
+
|
| 214 |
+
for (int i = 0; i < n_seq_tokens; ++i) {
|
| 215 |
+
const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
|
| 216 |
+
|
| 217 |
+
if (pos == 0) {
|
| 218 |
+
data[seq_id] = s*n_seq_tokens + i;
|
| 219 |
+
}
|
| 220 |
+
}
|
| 221 |
+
}
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
|
| 225 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 226 |
+
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
| 227 |
+
const int64_t n_seqs = ubatch->n_seqs;
|
| 228 |
+
|
| 229 |
+
GGML_ASSERT(cls);
|
| 230 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
|
| 231 |
+
|
| 232 |
+
uint32_t * data = (uint32_t *) cls->data;
|
| 233 |
+
memset(cls->data, 0, n_tokens * ggml_element_size(cls));
|
| 234 |
+
|
| 235 |
+
std::vector<int> last_pos(n_tokens, -1);
|
| 236 |
+
std::vector<int> last_row(n_tokens, -1);
|
| 237 |
+
|
| 238 |
+
for (int s = 0; s < n_seqs; ++s) {
|
| 239 |
+
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
| 240 |
+
|
| 241 |
+
// TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
|
| 242 |
+
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
|
| 243 |
+
|
| 244 |
+
for (int i = 0; i < n_seq_tokens; ++i) {
|
| 245 |
+
const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
|
| 246 |
+
|
| 247 |
+
if (pos >= last_pos[seq_id]) {
|
| 248 |
+
last_pos[seq_id] = pos;
|
| 249 |
+
last_row[seq_id] = s*n_seq_tokens + i;
|
| 250 |
+
}
|
| 251 |
+
}
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
for (int i = 0; i < n_tokens; ++i) {
|
| 255 |
+
if (last_row[i] >= 0) {
|
| 256 |
+
data[i] = last_row[i];
|
| 257 |
+
}
|
| 258 |
+
}
|
| 259 |
+
}
|
| 260 |
+
}
|
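Note: inp_cls ends up holding, per sequence id, the batch row of the single token whose embedding is kept: the pos == 0 token for CLS/RANK pooling, or the highest-position token for LAST pooling. A hand-worked example (assuming two sequences of 3 tokens each, occupying batch rows 0-2 and 3-5; unused entries stay zero):

    CLS/RANK: cls = [0, 3, 0, 0, 0, 0]
    LAST:     cls = [2, 5, 0, 0, 0, 0]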
| 261 |
+
|
| 262 |
+
void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
| 263 |
+
GGML_UNUSED(ubatch);
|
| 264 |
+
|
| 265 |
+
const int64_t n_kv = kv_self->n;
|
| 266 |
+
|
| 267 |
+
if (s_copy) {
|
| 268 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
| 269 |
+
int32_t * data = (int32_t *) s_copy->data;
|
| 270 |
+
|
| 271 |
+
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
| 272 |
+
for (uint32_t i = 0; i < n_kv; ++i) {
|
| 273 |
+
const uint32_t cell_id = i + kv_self->head;
|
| 274 |
+
|
| 275 |
+
//////////////////////////////////////////////
|
| 276 |
+
// TODO: this should not mutate the KV cache !
|
| 277 |
+
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
|
| 278 |
+
|
| 279 |
+
// prevent out-of-bound sources
|
| 280 |
+
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
|
| 281 |
+
kv_cell.src = cell_id;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
data[i] = kv_cell.src;
|
| 285 |
+
|
| 286 |
+
// TODO: do not mutate the KV cache
|
| 287 |
+
// ensure copy only happens once
|
| 288 |
+
if (kv_cell.src != (int32_t) cell_id) {
|
| 289 |
+
kv_cell.src = cell_id;
|
| 290 |
+
}
|
| 291 |
+
}
|
| 292 |
+
}
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
|
| 296 |
+
GGML_UNUSED(ubatch);
|
| 297 |
+
|
| 298 |
+
const int64_t n_kv = kv_self->n;
|
| 299 |
+
|
| 300 |
+
if (s_mask) {
|
| 301 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
|
| 302 |
+
float * data = (float *) s_mask->data;
|
| 303 |
+
|
| 304 |
+
// clear unused states
|
| 305 |
+
for (int i = 0; i < n_kv; ++i) {
|
| 306 |
+
const uint32_t cell_id = i + kv_self->head;
|
| 307 |
+
|
| 308 |
+
//////////////////////////////////////////////
|
| 309 |
+
// TODO: this should not mutate the KV cache !
|
| 310 |
+
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
|
| 311 |
+
|
| 312 |
+
data[i] = (float) (kv_cell.src >= 0);
|
| 313 |
+
|
| 314 |
+
// only clear once
|
| 315 |
+
if (kv_cell.src < 0) {
|
| 316 |
+
kv_cell.src = cell_id;
|
| 317 |
+
}
|
| 318 |
+
}
|
| 319 |
+
}
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
|
| 323 |
+
GGML_UNUSED(ubatch);
|
| 324 |
+
|
| 325 |
+
if (cross_embd && !cross->v_embd.empty()) {
|
| 326 |
+
assert(cross_embd->type == GGML_TYPE_F32);
|
| 327 |
+
|
| 328 |
+
ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd));
|
| 329 |
+
}
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
| 333 |
+
if (kq_mask) {
|
| 334 |
+
if (cparams.causal_attn) {
|
| 335 |
+
const int64_t n_kv = ubatch->n_tokens;
|
| 336 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 337 |
+
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
| 338 |
+
const int64_t n_seqs = ubatch->n_seqs;
|
| 339 |
+
|
| 340 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
|
| 341 |
+
float * data = (float *) kq_mask->data;
|
| 342 |
+
|
| 343 |
+
for (int h = 0; h < 1; ++h) {
|
| 344 |
+
for (int s1 = 0; s1 < n_seqs; ++s1) {
|
| 345 |
+
const llama_seq_id seq_id = ubatch->seq_id[s1][0];
|
| 346 |
+
|
| 347 |
+
for (int j = 0; j < n_seq_tokens; ++j) {
|
| 348 |
+
const int32_t tj = s1*n_seq_tokens + j;
|
| 349 |
+
|
| 350 |
+
for (int s0 = 0; s0 < n_seqs; ++s0) {
|
| 351 |
+
for (int i = 0; i < n_seq_tokens; ++i) {
|
| 352 |
+
const int32_t ti = s0*n_seq_tokens + i;
|
| 353 |
+
float f = -INFINITY;
|
| 354 |
+
|
| 355 |
+
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
|
| 356 |
+
if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
|
| 357 |
+
if (hparams.use_alibi) {
|
| 358 |
+
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
|
| 359 |
+
} else {
|
| 360 |
+
f = 0.0f;
|
| 361 |
+
}
|
| 362 |
+
break;
|
| 363 |
+
}
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
|
| 367 |
+
}
|
| 368 |
+
}
|
| 369 |
+
}
|
| 370 |
+
}
|
| 371 |
+
}
|
| 372 |
+
} else {
|
| 373 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 374 |
+
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
| 375 |
+
const int64_t n_seqs = ubatch->n_seqs;
|
| 376 |
+
const int64_t n_stride = ubatch->n_tokens;
|
| 377 |
+
|
| 378 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
|
| 379 |
+
|
| 380 |
+
float * data = (float *) kq_mask->data;
|
| 381 |
+
|
| 382 |
+
for (int h = 0; h < 1; ++h) {
|
| 383 |
+
for (int s1 = 0; s1 < n_seqs; ++s1) {
|
| 384 |
+
const llama_seq_id seq_id = ubatch->seq_id[s1][0];
|
| 385 |
+
|
| 386 |
+
for (int j = 0; j < n_seq_tokens; ++j) {
|
| 387 |
+
const int32_t tj = s1*n_seq_tokens + j;
|
| 388 |
+
|
| 389 |
+
for (int s0 = 0; s0 < n_seqs; ++s0) {
|
| 390 |
+
for (int i = 0; i < n_seq_tokens; ++i) {
|
| 391 |
+
const int32_t ti = s0*n_seq_tokens + i;
|
| 392 |
+
float f = -INFINITY;
|
| 393 |
+
|
| 394 |
+
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
|
| 395 |
+
if (ubatch->seq_id[s0][s] == seq_id) {
|
| 396 |
+
if (hparams.use_alibi) {
|
| 397 |
+
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
|
| 398 |
+
} else {
|
| 399 |
+
f = 0.0f;
|
| 400 |
+
}
|
| 401 |
+
break;
|
| 402 |
+
}
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
|
| 406 |
+
}
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
for (int i = n_tokens; i < n_stride; ++i) {
|
| 410 |
+
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
|
| 411 |
+
}
|
| 412 |
+
}
|
| 413 |
+
}
|
| 414 |
+
}
|
| 415 |
+
}
|
| 416 |
+
}
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
| 420 |
+
if (self_kq_mask || self_kq_mask_swa) {
|
| 421 |
+
const int64_t n_kv = kv_self->n;
|
| 422 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 423 |
+
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
| 424 |
+
const int64_t n_seqs = ubatch->n_seqs;
|
| 425 |
+
|
| 426 |
+
float * data = nullptr;
|
| 427 |
+
float * data_swa = nullptr;
|
| 428 |
+
|
| 429 |
+
if (self_kq_mask) {
|
| 430 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
|
| 431 |
+
data = (float *) self_kq_mask->data;
|
| 432 |
+
}
|
| 433 |
+
|
| 434 |
+
if (self_kq_mask_swa) {
|
| 435 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
|
| 436 |
+
data_swa = (float *) self_kq_mask_swa->data;
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
|
| 440 |
+
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
| 441 |
+
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
|
| 442 |
+
// Causal mask:
|
| 443 |
+
// xxx-------
|
| 444 |
+
// xxxx------
|
| 445 |
+
// xxxxx-----
|
| 446 |
+
// Non-causal mask:
|
| 447 |
+
// xxxxx-----
|
| 448 |
+
// xxxxx-----
|
| 449 |
+
// xxxxx-----
|
| 450 |
+
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
| 451 |
+
for (int h = 0; h < 1; ++h) {
|
| 452 |
+
for (int s = 0; s < n_seqs; ++s) {
|
| 453 |
+
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
| 454 |
+
|
| 455 |
+
for (int j = 0; j < n_seq_tokens; ++j) {
|
| 456 |
+
const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
|
| 457 |
+
for (int i = 0; i < n_kv; ++i) {
|
| 458 |
+
float f;
|
| 459 |
+
// mask the token if:
|
| 460 |
+
if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
|
| 461 |
+
|| (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
|
| 462 |
+
) {
|
| 463 |
+
f = -INFINITY;
|
| 464 |
+
} else {
|
| 465 |
+
if (hparams.use_alibi) {
|
| 466 |
+
f = -std::abs(kv_self->cells[i].pos - pos);
|
| 467 |
+
} else {
|
| 468 |
+
f = 0.0f;
|
| 469 |
+
}
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
if (data) {
|
| 473 |
+
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
| 474 |
+
}
|
| 475 |
+
|
| 476 |
+
// may need to cut off old tokens for sliding window
|
| 477 |
+
// TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
|
| 478 |
+
if (data_swa) {
|
| 479 |
+
if (hparams.n_attn_chunk) {
|
| 480 |
+
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
|
| 481 |
+
if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
|
| 482 |
+
f = -INFINITY;
|
| 483 |
+
}
|
| 484 |
+
} else {
|
| 485 |
+
if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
|
| 486 |
+
f = -INFINITY;
|
| 487 |
+
}
|
| 488 |
+
}
|
| 489 |
+
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
| 490 |
+
}
|
| 491 |
+
}
|
| 492 |
+
}
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
// mask padded tokens
|
| 496 |
+
if (data) {
|
| 497 |
+
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
| 498 |
+
for (int j = 0; j < n_kv; ++j) {
|
| 499 |
+
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
| 500 |
+
}
|
| 501 |
+
}
|
| 502 |
+
}
|
| 503 |
+
|
| 504 |
+
// mask padded tokens
|
| 505 |
+
if (data_swa) {
|
| 506 |
+
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
| 507 |
+
for (int j = 0; j < n_kv; ++j) {
|
| 508 |
+
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
| 509 |
+
}
|
| 510 |
+
}
|
| 511 |
+
}
|
| 512 |
+
}
|
| 513 |
+
}
|
| 514 |
+
}
|
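Note: a minimal standalone sketch (plain C++, toy sizes, not from llama.cpp) that reproduces the causal-mask staircase drawn in the comment inside the function above, assuming a single sequence with 2 cells already populated and 3 new batch tokens:

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int n_kv = 5, n_tokens = 3, n_past = 2;   // 2 populated cells + 3 batch tokens
    std::vector<float> mask(n_tokens*n_kv, -INFINITY);

    for (int j = 0; j < n_tokens; ++j) {            // query token (row)
        for (int i = 0; i < n_kv; ++i) {            // KV cell (column)
            if (i <= n_past + j) {                  // same sequence, not in the future
                mask[j*n_kv + i] = 0.0f;
            }
        }
    }

    // prints: xxx-- / xxxx- / xxxxx (the same staircase as the diagram above)
    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            std::putchar(mask[j*n_kv + i] == 0.0f ? 'x' : '-');
        }
        std::putchar('\n');
    }
    return 0;
}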
| 515 |
+
|
| 516 |
+
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
| 517 |
+
if (cross_kq_mask) {
|
| 518 |
+
const int64_t n_enc = cross_kq_mask->ne[0];
|
| 519 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 520 |
+
|
| 521 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
|
| 522 |
+
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
|
| 523 |
+
|
| 524 |
+
float * data = (float *) cross_kq_mask->data;
|
| 525 |
+
|
| 526 |
+
for (int h = 0; h < 1; ++h) {
|
| 527 |
+
for (int j = 0; j < n_tokens; ++j) {
|
| 528 |
+
for (int i = 0; i < n_enc; ++i) {
|
| 529 |
+
float f = -INFINITY;
|
| 530 |
+
for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
|
| 531 |
+
const llama_seq_id seq_id = ubatch->seq_id[j][s];
|
| 532 |
+
if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
|
| 533 |
+
f = 0.0f;
|
| 534 |
+
}
|
| 535 |
+
}
|
| 536 |
+
data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
|
| 537 |
+
}
|
| 538 |
+
}
|
| 539 |
+
|
| 540 |
+
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
| 541 |
+
for (int j = 0; j < n_enc; ++j) {
|
| 542 |
+
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
|
| 543 |
+
}
|
| 544 |
+
}
|
| 545 |
+
}
|
| 546 |
+
}
|
| 547 |
+
}
|
| 548 |
+
|
| 549 |
+
//
|
| 550 |
+
// llm_graph_context
|
| 551 |
+
//
|
| 552 |
+
|
| 553 |
+
llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
| 554 |
+
arch (params.arch),
|
| 555 |
+
hparams (params.hparams),
|
| 556 |
+
cparams (params.cparams),
|
| 557 |
+
ubatch (params.ubatch),
|
| 558 |
+
n_embd (hparams.n_embd),
|
| 559 |
+
n_layer (hparams.n_layer),
|
| 560 |
+
n_rot (hparams.n_rot),
|
| 561 |
+
n_ctx (cparams.n_ctx),
|
| 562 |
+
n_ctx_per_seq (cparams.n_ctx / cparams.n_seq_max),
|
| 563 |
+
n_head (hparams.n_head()),
|
| 564 |
+
n_head_kv (hparams.n_head_kv()),
|
| 565 |
+
n_embd_head_k (hparams.n_embd_head_k),
|
| 566 |
+
n_embd_k_gqa (hparams.n_embd_k_gqa()),
|
| 567 |
+
n_embd_head_v (hparams.n_embd_head_v),
|
| 568 |
+
n_embd_v_gqa (hparams.n_embd_v_gqa()),
|
| 569 |
+
n_expert (hparams.n_expert),
|
| 570 |
+
n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
|
| 571 |
+
freq_base (cparams.rope_freq_base),
|
| 572 |
+
freq_scale (cparams.rope_freq_scale),
|
| 573 |
+
ext_factor (cparams.yarn_ext_factor),
|
| 574 |
+
attn_factor (cparams.yarn_attn_factor),
|
| 575 |
+
beta_fast (cparams.yarn_beta_fast),
|
| 576 |
+
beta_slow (cparams.yarn_beta_slow),
|
| 577 |
+
norm_eps (hparams.f_norm_eps),
|
| 578 |
+
norm_rms_eps (hparams.f_norm_rms_eps),
|
| 579 |
+
n_tokens (ubatch.n_tokens),
|
| 580 |
+
n_outputs (params.n_outputs),
|
| 581 |
+
n_ctx_orig (cparams.n_ctx_orig_yarn),
|
| 582 |
+
pooling_type (cparams.pooling_type),
|
| 583 |
+
rope_type (hparams.rope_type),
|
| 584 |
+
ctx0 (params.ctx),
|
| 585 |
+
sched (params.sched),
|
| 586 |
+
backend_cpu (params.backend_cpu),
|
| 587 |
+
cvec (params.cvec),
|
| 588 |
+
loras (params.loras),
|
| 589 |
+
memory (params.memory),
|
| 590 |
+
cross (params.cross),
|
| 591 |
+
cb_func (params.cb),
|
| 592 |
+
res (std::make_unique<llm_graph_result>()) {
|
| 593 |
+
}
|
| 594 |
+
|
| 595 |
+
int64_t llm_graph_context::n_pos_per_token() const {
|
| 596 |
+
return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
|
| 597 |
+
}
|
| 598 |
+
|
| 599 |
+
void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
|
| 600 |
+
if (cb_func) {
|
| 601 |
+
cb_func(ubatch, cur, name, il);
|
| 602 |
+
}
|
| 603 |
+
}
|
| 604 |
+
|
| 605 |
+
ggml_tensor * llm_graph_context::build_cvec(
|
| 606 |
+
ggml_tensor * cur,
|
| 607 |
+
int il) const {
|
| 608 |
+
return cvec->apply_to(ctx0, cur, il);
|
| 609 |
+
}
|
| 610 |
+
|
| 611 |
+
ggml_tensor * llm_graph_context::build_lora_mm(
|
| 612 |
+
ggml_tensor * w,
|
| 613 |
+
ggml_tensor * cur) const {
|
| 614 |
+
ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
|
| 615 |
+
|
| 616 |
+
for (const auto & lora : *loras) {
|
| 617 |
+
llama_adapter_lora_weight * lw = lora.first->get_weight(w);
|
| 618 |
+
if (lw == nullptr) {
|
| 619 |
+
continue;
|
| 620 |
+
}
|
| 621 |
+
|
| 622 |
+
const float adapter_scale = lora.second;
|
| 623 |
+
const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
|
| 624 |
+
|
| 625 |
+
ggml_tensor * ab_cur = ggml_mul_mat(
|
| 626 |
+
ctx0, lw->b,
|
| 627 |
+
ggml_mul_mat(ctx0, lw->a, cur)
|
| 628 |
+
);
|
| 629 |
+
|
| 630 |
+
ab_cur = ggml_scale(ctx0, ab_cur, scale);
|
| 631 |
+
res = ggml_add(ctx0, res, ab_cur);
|
| 632 |
+
}
|
| 633 |
+
|
| 634 |
+
return res;
|
| 635 |
+
}
|
| 636 |
+
|
| 637 |
+
ggml_tensor * llm_graph_context::build_lora_mm_id(
|
| 638 |
+
ggml_tensor * w, // ggml_tensor * as
|
| 639 |
+
ggml_tensor * cur, // ggml_tensor * b
|
| 640 |
+
ggml_tensor * ids) const {
|
| 641 |
+
ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
|
| 642 |
+
for (const auto & lora : *loras) {
|
| 643 |
+
llama_adapter_lora_weight * lw = lora.first->get_weight(w);
|
| 644 |
+
if (lw == nullptr) {
|
| 645 |
+
continue;
|
| 646 |
+
}
|
| 647 |
+
|
| 648 |
+
const float alpha = lora.first->alpha;
|
| 649 |
+
const float rank = (float) lw->b->ne[0];
|
| 650 |
+
const float scale = alpha ? lora.second * alpha / rank : lora.second;
|
| 651 |
+
|
| 652 |
+
ggml_tensor * ab_cur = ggml_mul_mat_id(
|
| 653 |
+
ctx0, lw->b,
|
| 654 |
+
ggml_mul_mat_id(ctx0, lw->a, cur, ids),
|
| 655 |
+
ids
|
| 656 |
+
);
|
| 657 |
+
|
| 658 |
+
ab_cur = ggml_scale(ctx0, ab_cur, scale);
|
| 659 |
+
res = ggml_add(ctx0, res, ab_cur);
|
| 660 |
+
}
|
| 661 |
+
|
| 662 |
+
return res;
|
| 663 |
+
}
|
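Note: both LoRA helpers above compute the same adjusted product; written out, for a base weight W and an adapter pair (A, B) with rank r, alpha \alpha and user scale s:

    y = W x + s_{eff} \, B (A x), \qquad s_{eff} = \begin{cases} s \cdot \alpha / r & \alpha \neq 0 \\ s & \alpha = 0 \end{cases}

The \alpha / r convention is written out explicitly in build_lora_mm_id; build_lora_mm is assumed to obtain the same value through lw->get_scale().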
| 664 |
+
|
| 665 |
+
ggml_tensor * llm_graph_context::build_norm(
|
| 666 |
+
ggml_tensor * cur,
|
| 667 |
+
ggml_tensor * mw,
|
| 668 |
+
ggml_tensor * mb,
|
| 669 |
+
llm_norm_type type,
|
| 670 |
+
int il) const {
|
| 671 |
+
switch (type) {
|
| 672 |
+
case LLM_NORM: cur = ggml_norm (ctx0, cur, hparams.f_norm_eps); break;
|
| 673 |
+
case LLM_NORM_RMS: cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); break;
|
| 674 |
+
case LLM_NORM_GROUP:
|
| 675 |
+
{
|
| 676 |
+
cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], 1, cur->ne[1]);
|
| 677 |
+
cur = ggml_group_norm(ctx0, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
|
| 678 |
+
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[2]);
|
| 679 |
+
} break;
|
| 680 |
+
}
|
| 681 |
+
|
| 682 |
+
if (mw || mb) {
|
| 683 |
+
cb(cur, "norm", il);
|
| 684 |
+
}
|
| 685 |
+
|
| 686 |
+
if (mw) {
|
| 687 |
+
cur = ggml_mul(ctx0, cur, mw);
|
| 688 |
+
if (mb) {
|
| 689 |
+
cb(cur, "norm_w", il);
|
| 690 |
+
}
|
| 691 |
+
}
|
| 692 |
+
|
| 693 |
+
if (mb) {
|
| 694 |
+
cur = ggml_add(ctx0, cur, mb);
|
| 695 |
+
}
|
| 696 |
+
|
| 697 |
+
return cur;
|
| 698 |
+
}
|
| 699 |
+
|
| 700 |
+
ggml_tensor * llm_graph_context::build_ffn(
|
| 701 |
+
ggml_tensor * cur,
|
| 702 |
+
ggml_tensor * up,
|
| 703 |
+
ggml_tensor * up_b,
|
| 704 |
+
ggml_tensor * up_s,
|
| 705 |
+
ggml_tensor * gate,
|
| 706 |
+
ggml_tensor * gate_b,
|
| 707 |
+
ggml_tensor * gate_s,
|
| 708 |
+
ggml_tensor * down,
|
| 709 |
+
ggml_tensor * down_b,
|
| 710 |
+
ggml_tensor * down_s,
|
| 711 |
+
ggml_tensor * act_scales,
|
| 712 |
+
llm_ffn_op_type type_op,
|
| 713 |
+
llm_ffn_gate_type type_gate,
|
| 714 |
+
int il) const {
|
| 715 |
+
ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur;
|
| 716 |
+
cb(tmp, "ffn_up", il);
|
| 717 |
+
|
| 718 |
+
if (up_b) {
|
| 719 |
+
tmp = ggml_add(ctx0, tmp, up_b);
|
| 720 |
+
cb(tmp, "ffn_up_b", il);
|
| 721 |
+
}
|
| 722 |
+
|
| 723 |
+
if (up_s) {
|
| 724 |
+
tmp = ggml_mul(ctx0, tmp, up_s);
|
| 725 |
+
cb(tmp, "ffn_up_s", il);
|
| 726 |
+
}
|
| 727 |
+
|
| 728 |
+
if (gate) {
|
| 729 |
+
switch (type_gate) {
|
| 730 |
+
case LLM_FFN_SEQ:
|
| 731 |
+
{
|
| 732 |
+
cur = build_lora_mm(gate, tmp);
|
| 733 |
+
cb(cur, "ffn_gate", il);
|
| 734 |
+
} break;
|
| 735 |
+
case LLM_FFN_PAR:
|
| 736 |
+
{
|
| 737 |
+
cur = build_lora_mm(gate, cur);
|
| 738 |
+
cb(cur, "ffn_gate", il);
|
| 739 |
+
} break;
|
| 740 |
+
}
|
| 741 |
+
|
| 742 |
+
if (gate_b) {
|
| 743 |
+
cur = ggml_add(ctx0, cur, gate_b);
|
| 744 |
+
cb(cur, "ffn_gate_b", il);
|
| 745 |
+
}
|
| 746 |
+
|
| 747 |
+
if (gate_s) {
|
| 748 |
+
cur = ggml_mul(ctx0, cur, gate_s);
|
| 749 |
+
cb(cur, "ffn_gate_s", il);
|
| 750 |
+
}
|
| 751 |
+
|
| 752 |
+
} else {
|
| 753 |
+
cur = tmp;
|
| 754 |
+
}
|
| 755 |
+
|
| 756 |
+
switch (type_op) {
|
| 757 |
+
case LLM_FFN_SILU:
|
| 758 |
+
{
|
| 759 |
+
cur = ggml_silu(ctx0, cur);
|
| 760 |
+
cb(cur, "ffn_silu", il);
|
| 761 |
+
} break;
|
| 762 |
+
case LLM_FFN_GELU:
|
| 763 |
+
{
|
| 764 |
+
cur = ggml_gelu(ctx0, cur);
|
| 765 |
+
cb(cur, "ffn_gelu", il);
|
| 766 |
+
if (act_scales != NULL) {
|
| 767 |
+
cur = ggml_div(ctx0, cur, act_scales);
|
| 768 |
+
cb(cur, "ffn_act", il);
|
| 769 |
+
}
|
| 770 |
+
} break;
|
| 771 |
+
case LLM_FFN_RELU:
|
| 772 |
+
{
|
| 773 |
+
cur = ggml_relu(ctx0, cur);
|
| 774 |
+
cb(cur, "ffn_relu", il);
|
| 775 |
+
} break;
|
| 776 |
+
case LLM_FFN_RELU_SQR:
|
| 777 |
+
{
|
| 778 |
+
cur = ggml_relu(ctx0, cur);
|
| 779 |
+
cb(cur, "ffn_relu", il);
|
| 780 |
+
|
| 781 |
+
cur = ggml_sqr(ctx0, cur);
|
| 782 |
+
cb(cur, "ffn_sqr(relu)", il);
|
| 783 |
+
} break;
|
| 784 |
+
case LLM_FFN_SWIGLU:
|
| 785 |
+
{
|
| 786 |
+
// Project to 4h. If using swiglu, double the output width (see https://arxiv.org/pdf/2002.05202.pdf)
|
| 787 |
+
int64_t split_point = cur->ne[0] / 2;
|
| 788 |
+
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
| 789 |
+
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
| 790 |
+
|
| 791 |
+
x0 = ggml_silu(ctx0, x0);
|
| 792 |
+
cb(cur, "ffn_silu", il);
|
| 793 |
+
|
| 794 |
+
cur = ggml_mul(ctx0, x0, x1);
|
| 795 |
+
cb(cur, "ffn_mul", il);
|
| 796 |
+
} break;
|
| 797 |
+
}
|
| 798 |
+
|
| 799 |
+
if (type_gate == LLM_FFN_PAR) {
|
| 800 |
+
cur = ggml_mul(ctx0, cur, tmp);
|
| 801 |
+
cb(cur, "ffn_gate_par", il);
|
| 802 |
+
}
|
| 803 |
+
|
| 804 |
+
if (down) {
|
| 805 |
+
cur = build_lora_mm(down, cur);
|
| 806 |
+
}
|
| 807 |
+
|
| 808 |
+
if (down_b) {
|
| 809 |
+
cb(cur, "ffn_down", il);
|
| 810 |
+
}
|
| 811 |
+
|
| 812 |
+
if (down_b) {
|
| 813 |
+
cur = ggml_add(ctx0, cur, down_b);
|
| 814 |
+
}
|
| 815 |
+
|
| 816 |
+
if (down_s) {
|
| 817 |
+
cur = ggml_mul(ctx0, cur, down_s);
|
| 818 |
+
cb(cur, "ffn_down_s", il);
|
| 819 |
+
}
|
| 820 |
+
|
| 821 |
+
return cur;
|
| 822 |
+
}
|
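Note: the LLM_FFN_SWIGLU branch above fuses the gate into the up projection: the up matmul produces a tensor whose first dimension is split in half at split_point and recombined as

    \mathrm{FFN}(x) = W_{down}\,\big(\mathrm{SiLU}(x_0) \odot x_1\big), \qquad [x_0 \,;\, x_1] = W_{up}\, x

which is why the comment notes that the output width of the up projection is doubled when SwiGLU is used.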
| 823 |
+
|
| 824 |
+
ggml_tensor * llm_graph_context::build_moe_ffn(
|
| 825 |
+
ggml_tensor * cur,
|
| 826 |
+
ggml_tensor * gate_inp,
|
| 827 |
+
ggml_tensor * up_exps,
|
| 828 |
+
ggml_tensor * gate_exps,
|
| 829 |
+
ggml_tensor * down_exps,
|
| 830 |
+
ggml_tensor * exp_probs_b,
|
| 831 |
+
int64_t n_expert,
|
| 832 |
+
int64_t n_expert_used,
|
| 833 |
+
llm_ffn_op_type type_op,
|
| 834 |
+
bool norm_w,
|
| 835 |
+
bool scale_w,
|
| 836 |
+
float w_scale,
|
| 837 |
+
llama_expert_gating_func_type gating_op,
|
| 838 |
+
int il) const {
|
| 839 |
+
const int64_t n_embd = cur->ne[0];
|
| 840 |
+
const int64_t n_tokens = cur->ne[1];
|
| 841 |
+
const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
|
| 842 |
+
|
| 843 |
+
ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
|
| 844 |
+
cb(logits, "ffn_moe_logits", il);
|
| 845 |
+
|
| 846 |
+
ggml_tensor * probs = nullptr;
|
| 847 |
+
switch (gating_op) {
|
| 848 |
+
case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
|
| 849 |
+
{
|
| 850 |
+
probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
|
| 851 |
+
} break;
|
| 852 |
+
case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
|
| 853 |
+
{
|
| 854 |
+
probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
|
| 855 |
+
} break;
|
| 856 |
+
default:
|
| 857 |
+
GGML_ABORT("fatal error");
|
| 858 |
+
}
|
| 859 |
+
cb(probs, "ffn_moe_probs", il);
|
| 860 |
+
|
| 861 |
+
// add experts selection bias - introduced in DeepSeek V3
|
| 862 |
+
// leave probs unbiased as it's later used to get expert weights
|
| 863 |
+
ggml_tensor * selection_probs = probs;
|
| 864 |
+
if (exp_probs_b != nullptr) {
|
| 865 |
+
selection_probs = ggml_add(ctx0, probs, exp_probs_b);
|
| 866 |
+
cb(selection_probs, "ffn_moe_probs_biased", il);
|
| 867 |
+
}
|
| 868 |
+
|
| 869 |
+
// llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
|
| 870 |
+
// see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
|
| 871 |
+
if (arch == LLM_ARCH_LLAMA4) {
|
| 872 |
+
selection_probs = logits;
|
| 873 |
+
}
|
| 874 |
+
|
| 875 |
+
// select experts
|
| 876 |
+
ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
|
| 877 |
+
cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
| 878 |
+
cb(selected_experts, "ffn_moe_topk", il);
|
| 879 |
+
|
| 880 |
+
ggml_tensor * weights = ggml_get_rows(ctx0,
|
| 881 |
+
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
|
| 882 |
+
cb(weights, "ffn_moe_weights", il);
|
| 883 |
+
|
| 884 |
+
if (norm_w) {
|
| 885 |
+
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
|
| 886 |
+
|
| 887 |
+
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
|
| 888 |
+
cb(weights_sum, "ffn_moe_weights_sum", il);
|
| 889 |
+
|
| 890 |
+
weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
|
| 891 |
+
cb(weights, "ffn_moe_weights_norm", il);
|
| 892 |
+
|
| 893 |
+
weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
|
| 894 |
+
}
|
| 895 |
+
if (scale_w) {
|
| 896 |
+
weights = ggml_scale(ctx0, weights, w_scale);
|
| 897 |
+
cb(weights, "ffn_moe_weights_scaled", il);
|
| 898 |
+
}
|
| 899 |
+
|
| 900 |
+
cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
|
| 901 |
+
|
| 902 |
+
if (weight_before_ffn) {
|
| 903 |
+
// TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
|
| 904 |
+
ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
|
| 905 |
+
repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
|
| 906 |
+
cur = ggml_mul(ctx0, repeated, weights);
|
| 907 |
+
cb(cur, "ffn_moe_weighted", il);
|
| 908 |
+
}
|
| 909 |
+
|
| 910 |
+
ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
| 911 |
+
cb(up, "ffn_moe_up", il);
|
| 912 |
+
|
| 913 |
+
ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
| 914 |
+
cb(gate, "ffn_moe_gate", il);
|
| 915 |
+
|
| 916 |
+
switch (type_op) {
|
| 917 |
+
case LLM_FFN_SILU:
|
| 918 |
+
{
|
| 919 |
+
gate = ggml_silu(ctx0, gate);
|
| 920 |
+
cb(gate, "ffn_moe_silu", il);
|
| 921 |
+
} break;
|
| 922 |
+
case LLM_FFN_GELU:
|
| 923 |
+
{
|
| 924 |
+
gate = ggml_gelu(ctx0, gate);
|
| 925 |
+
cb(gate, "ffn_moe_gelu", il);
|
| 926 |
+
} break;
|
| 927 |
+
default:
|
| 928 |
+
GGML_ABORT("fatal error");
|
| 929 |
+
}
|
| 930 |
+
|
| 931 |
+
ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
|
| 932 |
+
cb(par, "ffn_moe_gate_par", il);
|
| 933 |
+
|
| 934 |
+
ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
| 935 |
+
cb(experts, "ffn_moe_down", il);
|
| 936 |
+
|
| 937 |
+
if (!weight_before_ffn) {
|
| 938 |
+
experts = ggml_mul(ctx0, experts, weights);
|
| 939 |
+
cb(cur, "ffn_moe_weighted", il);
|
| 940 |
+
}
|
| 941 |
+
|
| 942 |
+
// aggregate experts
|
| 943 |
+
ggml_tensor * moe_out = nullptr;
|
| 944 |
+
for (int i = 0; i < n_expert_used; ++i) {
|
| 945 |
+
ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
|
| 946 |
+
experts->nb[2], i*experts->nb[1]);
|
| 947 |
+
|
| 948 |
+
if (i == 0) {
|
| 949 |
+
moe_out = cur_expert;
|
| 950 |
+
} else {
|
| 951 |
+
moe_out = ggml_add(ctx0, moe_out, cur_expert);
|
| 952 |
+
}
|
| 953 |
+
}
|
| 954 |
+
|
| 955 |
+
if (n_expert_used == 1) {
|
| 956 |
+
// avoid returning a non-contiguous tensor
|
| 957 |
+
moe_out = ggml_cont(ctx0, moe_out);
|
| 958 |
+
}
|
| 959 |
+
|
| 960 |
+
cb(moe_out, "ffn_moe_out", il);
|
| 961 |
+
|
| 962 |
+
return moe_out;
|
| 963 |
+
}
|
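Note: a minimal CPU-only sketch (plain C++, made-up values, no ggml) of the routing math performed above: softmax gating, top-k expert selection, and the optional renormalization of the selected weights (the norm_w branch).

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
    const std::vector<float> logits = {0.2f, 1.5f, -0.3f, 0.9f}; // n_expert = 4 (made-up gating logits)
    const int n_expert_used = 2;

    // softmax over the gating logits (LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX)
    std::vector<float> probs(logits.size());
    const float max_l = *std::max_element(logits.begin(), logits.end());
    float sum = 0.0f;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_l);
        sum += probs[i];
    }
    for (float & p : probs) {
        p /= sum;
    }

    // keep the n_expert_used highest-probability experts (what ggml_top_k selects)
    std::vector<int> idx(probs.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + n_expert_used, idx.end(),
        [&](int a, int b) { return probs[a] > probs[b]; });

    // renormalize the selected weights so they sum to 1 (the norm_w branch)
    float w_sum = 0.0f;
    for (int k = 0; k < n_expert_used; ++k) {
        w_sum += probs[idx[k]];
    }
    for (int k = 0; k < n_expert_used; ++k) {
        std::printf("expert %d -> weight %.3f\n", idx[k], probs[idx[k]] / w_sum);
    }
    return 0;
}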
| 964 |
+
|
| 965 |
+
// input embeddings with optional lora
|
| 966 |
+
ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
| 967 |
+
const int64_t n_embd = hparams.n_embd;
|
| 968 |
+
|
| 969 |
+
auto inp = std::make_unique<llm_graph_input_embd>();
|
| 970 |
+
|
| 971 |
+
ggml_tensor * cur = nullptr;
|
| 972 |
+
|
| 973 |
+
if (ubatch.token) {
|
| 974 |
+
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
| 975 |
+
//cb(inp->tokens, "inp_tokens", -1);
|
| 976 |
+
ggml_set_input(inp->tokens);
|
| 977 |
+
|
| 978 |
+
cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
|
| 979 |
+
|
| 980 |
+
// apply lora for embedding tokens if needed
|
| 981 |
+
for (const auto & lora : *loras) {
|
| 982 |
+
llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
|
| 983 |
+
if (lw == nullptr) {
|
| 984 |
+
continue;
|
| 985 |
+
}
|
| 986 |
+
|
| 987 |
+
const float adapter_scale = lora.second;
|
| 988 |
+
const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
|
| 989 |
+
|
| 990 |
+
ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
|
| 991 |
+
ctx0, lw->b, // non-transposed lora_b
|
| 992 |
+
ggml_get_rows(ctx0, lw->a, inp->tokens)
|
| 993 |
+
), scale);
|
| 994 |
+
|
| 995 |
+
cur = ggml_add(ctx0, cur, inpL_delta);
|
| 996 |
+
}
|
| 997 |
+
} else {
|
| 998 |
+
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
|
| 999 |
+
ggml_set_input(inp->embd);
|
| 1000 |
+
|
| 1001 |
+
cur = inp->embd;
|
| 1002 |
+
}
|
| 1003 |
+
|
| 1004 |
+
// For Granite architecture
|
| 1005 |
+
if (hparams.f_embedding_scale != 0.0f) {
|
| 1006 |
+
cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
|
| 1007 |
+
}
|
| 1008 |
+
|
| 1009 |
+
cb(cur, "inp_embd", -1);
|
| 1010 |
+
|
| 1011 |
+
res->add_input(std::move(inp));
|
| 1012 |
+
|
| 1013 |
+
return cur;
|
| 1014 |
+
}
|
| 1015 |
+
|
| 1016 |
+
ggml_tensor * llm_graph_context::build_inp_pos() const {
|
| 1017 |
+
auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_token());
|
| 1018 |
+
|
| 1019 |
+
auto & cur = inp->pos;
|
| 1020 |
+
|
| 1021 |
+
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token());
|
| 1022 |
+
ggml_set_input(cur);
|
| 1023 |
+
|
| 1024 |
+
res->add_input(std::move(inp));
|
| 1025 |
+
|
| 1026 |
+
return cur;
|
| 1027 |
+
}
|
| 1028 |
+
|
| 1029 |
+
ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
|
| 1030 |
+
auto inp = std::make_unique<llm_graph_input_attn_temp>(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
|
| 1031 |
+
|
| 1032 |
+
auto & cur = inp->attn_scale;
|
| 1033 |
+
|
| 1034 |
+
cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token());
|
| 1035 |
+
ggml_set_input(cur);
|
| 1036 |
+
|
| 1037 |
+
res->add_input(std::move(inp));
|
| 1038 |
+
|
| 1039 |
+
return cur;
|
| 1040 |
+
}
|
| 1041 |
+
|
| 1042 |
+
ggml_tensor * llm_graph_context::build_inp_out_ids() const {
|
| 1043 |
+
auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
|
| 1044 |
+
|
| 1045 |
+
auto & cur = inp->out_ids;
|
| 1046 |
+
|
| 1047 |
+
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
|
| 1048 |
+
ggml_set_input(cur);
|
| 1049 |
+
|
| 1050 |
+
res->add_input(std::move(inp));
|
| 1051 |
+
|
| 1052 |
+
return cur;
|
| 1053 |
+
}
|
| 1054 |
+
|
| 1055 |
+
ggml_tensor * llm_graph_context::build_inp_mean() const {
|
| 1056 |
+
auto inp = std::make_unique<llm_graph_input_mean>(cparams);
|
| 1057 |
+
|
| 1058 |
+
auto & cur = inp->mean;
|
| 1059 |
+
|
| 1060 |
+
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
|
| 1061 |
+
ggml_set_input(cur);
|
| 1062 |
+
|
| 1063 |
+
res->add_input(std::move(inp));
|
| 1064 |
+
|
| 1065 |
+
return cur;
|
| 1066 |
+
}
|
| 1067 |
+
|
| 1068 |
+
ggml_tensor * llm_graph_context::build_inp_cls() const {
|
| 1069 |
+
auto inp = std::make_unique<llm_graph_input_cls>(cparams);
|
| 1070 |
+
|
| 1071 |
+
auto & cur = inp->cls;
|
| 1072 |
+
|
| 1073 |
+
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
| 1074 |
+
ggml_set_input(cur);
|
| 1075 |
+
|
| 1076 |
+
res->add_input(std::move(inp));
|
| 1077 |
+
|
| 1078 |
+
return cur;
|
| 1079 |
+
}
|
| 1080 |
+
|
| 1081 |
+
ggml_tensor * llm_graph_context::build_inp_s_copy() const {
|
| 1082 |
+
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
| 1083 |
+
|
| 1084 |
+
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
|
| 1085 |
+
|
| 1086 |
+
const auto n_kv = kv_self->n;
|
| 1087 |
+
|
| 1088 |
+
auto & cur = inp->s_copy;
|
| 1089 |
+
|
| 1090 |
+
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
|
| 1091 |
+
ggml_set_input(cur);
|
| 1092 |
+
|
| 1093 |
+
res->add_input(std::move(inp));
|
| 1094 |
+
|
| 1095 |
+
return cur;
|
| 1096 |
+
}
|
| 1097 |
+
|
| 1098 |
+
ggml_tensor * llm_graph_context::build_inp_s_mask() const {
|
| 1099 |
+
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
| 1100 |
+
|
| 1101 |
+
auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
|
| 1102 |
+
|
| 1103 |
+
const auto n_kv = kv_self->n;
|
| 1104 |
+
|
| 1105 |
+
auto & cur = inp->s_mask;
|
| 1106 |
+
|
| 1107 |
+
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
|
| 1108 |
+
ggml_set_input(cur);
|
| 1109 |
+
|
| 1110 |
+
res->add_input(std::move(inp));
|
| 1111 |
+
|
| 1112 |
+
return cur;
|
| 1113 |
+
}
|
| 1114 |
+
|
| 1115 |
+
ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
|
| 1116 |
+
auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
|
| 1117 |
+
|
| 1118 |
+
auto & cur = inp->cross_embd;
|
| 1119 |
+
|
| 1120 |
+
// if we have the output embeddings from the encoder, use them directly
|
| 1121 |
+
// TODO: needs more work to be correct, for now just use the tensor shape
|
| 1122 |
+
//if (cross->t_embd) {
|
| 1123 |
+
// cur = ggml_view_tensor(ctx0, cross->t_embd);
|
| 1124 |
+
|
| 1125 |
+
// return cur;
|
| 1126 |
+
//}
|
| 1127 |
+
|
| 1128 |
+
const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd;
|
| 1129 |
+
const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
| 1130 |
+
|
| 1131 |
+
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
|
| 1132 |
+
ggml_set_input(cur);
|
| 1133 |
+
|
| 1134 |
+
res->add_input(std::move(inp));
|
| 1135 |
+
|
| 1136 |
+
return cur;
|
| 1137 |
+
}
|
| 1138 |
+
|
| 1139 |
+
ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
|
| 1140 |
+
auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);
|
| 1141 |
+
|
| 1142 |
+
auto & cur = inp->pos_bucket;
|
| 1143 |
+
|
| 1144 |
+
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
|
| 1145 |
+
ggml_set_input(cur);
|
| 1146 |
+
|
| 1147 |
+
res->add_input(std::move(inp));
|
| 1148 |
+
|
| 1149 |
+
return cur;
|
| 1150 |
+
}
|
| 1151 |
+
|
| 1152 |
+
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
|
| 1153 |
+
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
| 1154 |
+
|
| 1155 |
+
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
|
| 1156 |
+
|
| 1157 |
+
const auto n_kv = kv_self->n;
|
| 1158 |
+
|
| 1159 |
+
auto & cur = inp->pos_bucket;
|
| 1160 |
+
|
| 1161 |
+
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
|
| 1162 |
+
ggml_set_input(cur);
|
| 1163 |
+
|
| 1164 |
+
res->add_input(std::move(inp));
|
| 1165 |
+
|
| 1166 |
+
return cur;
|
| 1167 |
+
}
|
| 1168 |
+
|
| 1169 |
+
ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const {
|
| 1170 |
+
ggml_tensor * pos_bucket_1d = ggml_reshape_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1]);
|
| 1171 |
+
cb(pos_bucket_1d, "pos_bucket_1d", -1);
|
| 1172 |
+
|
| 1173 |
+
ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
|
| 1174 |
+
|
| 1175 |
+
pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]);
|
| 1176 |
+
pos_bias = ggml_permute (ctx0, pos_bias, 2, 0, 1, 3);
|
| 1177 |
+
pos_bias = ggml_cont (ctx0, pos_bias);
|
| 1178 |
+
|
| 1179 |
+
cb(pos_bias, "pos_bias", -1);
|
| 1180 |
+
|
| 1181 |
+
return pos_bias;
|
| 1182 |
+
}
|
| 1183 |
+
|
| 1184 |
+
ggml_tensor * llm_graph_context::build_attn_mha(
|
| 1185 |
+
ggml_cgraph * gf,
|
| 1186 |
+
ggml_tensor * q,
|
| 1187 |
+
ggml_tensor * k,
|
| 1188 |
+
ggml_tensor * v,
|
| 1189 |
+
ggml_tensor * kq_b,
|
| 1190 |
+
ggml_tensor * kq_mask,
|
| 1191 |
+
ggml_tensor * v_mla,
|
| 1192 |
+
bool v_trans,
|
| 1193 |
+
float kq_scale) const {
|
| 1194 |
+
//const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
| 1195 |
+
//const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
| 1196 |
+
|
| 1197 |
+
//const int64_t n_head = hparams.n_head(il);
|
| 1198 |
+
//const int64_t n_head_kv = hparams.n_head_kv(il);
|
| 1199 |
+
|
| 1200 |
+
//const auto & n_embd_head_k = hparams.n_embd_head_k;
|
| 1201 |
+
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
| 1202 |
+
|
| 1203 |
+
const auto n_tokens = q->ne[1];
|
| 1204 |
+
const auto n_head = q->ne[2];
|
| 1205 |
+
const auto n_kv = k->ne[1];
|
| 1206 |
+
|
| 1207 |
+
ggml_tensor * cur;
|
| 1208 |
+
|
| 1209 |
+
// TODO: replace hardcoded padding with ggml-provided padding
|
| 1210 |
+
if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
|
| 1211 |
+
GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
|
| 1212 |
+
|
| 1213 |
+
if (v_trans) {
|
| 1214 |
+
v = ggml_transpose(ctx0, v);
|
| 1215 |
+
}
|
| 1216 |
+
|
| 1217 |
+
// this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
|
| 1218 |
+
if (k->type == GGML_TYPE_F32) {
|
| 1219 |
+
k = ggml_cast(ctx0, k, GGML_TYPE_F16);
|
| 1220 |
+
}
|
| 1221 |
+
|
| 1222 |
+
if (v->type == GGML_TYPE_F32) {
|
| 1223 |
+
v = ggml_cast(ctx0, v, GGML_TYPE_F16);
|
| 1224 |
+
}
|
| 1225 |
+
|
| 1226 |
+
cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
|
| 1227 |
+
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
| 1228 |
+
|
| 1229 |
+
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
|
| 1230 |
+
|
| 1231 |
+
if (v_mla) {
|
| 1232 |
+
cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
|
| 1233 |
+
cur = ggml_mul_mat(ctx0, v_mla, cur);
|
| 1234 |
+
}
|
| 1235 |
+
|
| 1236 |
+
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
|
| 1237 |
+
} else {
|
| 1238 |
+
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
| 1239 |
+
|
| 1240 |
+
// note: this op tends to require high floating point range
|
| 1241 |
+
// while for some models F16 is enough, for others it is not, so we default to F32 here
|
| 1242 |
+
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
| 1243 |
+
|
| 1244 |
+
if (arch == LLM_ARCH_GROK) {
|
| 1245 |
+
// need to do the following:
|
| 1246 |
+
// multiply by attn_output_multiplier of 0.08838834764831845
|
| 1247 |
+
// and then :
|
| 1248 |
+
// kq = 30 * tanh(kq / 30)
|
| 1249 |
+
// before the softmax below
|
| 1250 |
+
|
| 1251 |
+
kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
|
| 1252 |
+
kq = ggml_scale(ctx0, kq, 30);
|
| 1253 |
+
}
|
| 1254 |
+
|
| 1255 |
+
if (hparams.attn_soft_cap) {
|
| 1256 |
+
kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
|
| 1257 |
+
kq = ggml_tanh (ctx0, kq);
|
| 1258 |
+
kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
|
| 1259 |
+
}
|
| 1260 |
+
|
| 1261 |
+
if (kq_b) {
|
| 1262 |
+
kq = ggml_add(ctx0, kq, kq_b);
|
| 1263 |
+
}
|
| 1264 |
+
|
| 1265 |
+
kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
|
| 1266 |
+
|
| 1267 |
+
if (!v_trans) {
|
| 1268 |
+
// note: avoid this branch
|
| 1269 |
+
v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
|
| 1270 |
+
}
|
| 1271 |
+
|
| 1272 |
+
ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
|
| 1273 |
+
|
| 1274 |
+
// for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
|
| 1275 |
+
if (v_mla) {
|
| 1276 |
+
kqv = ggml_mul_mat(ctx0, v_mla, kqv);
|
| 1277 |
+
}
|
| 1278 |
+
|
| 1279 |
+
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
| 1280 |
+
|
| 1281 |
+
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
|
| 1282 |
+
|
| 1283 |
+
if (!cparams.offload_kqv) {
|
| 1284 |
+
// all nodes between the KV store and the attention output are run on the CPU
|
| 1285 |
+
ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
|
| 1286 |
+
}
|
| 1287 |
+
}
|
| 1288 |
+
|
| 1289 |
+
ggml_build_forward_expand(gf, cur);
|
| 1290 |
+
|
| 1291 |
+
return cur;
|
| 1292 |
+
}
|
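Note: the non-flash path above is the standard scaled-dot-product attention, with the optional logit soft-cap applied to QK^T before scaling and softmax:

    \mathrm{Attn}(Q, K, V) = \mathrm{softmax}\big(s \cdot QK^{\top} + M\big)\, V, \qquad kq \leftarrow c \cdot \tanh(kq / c)

where s is kq_scale, M is the additive mask built by the set_input functions earlier (0 for visible positions, -INFINITY for masked ones, or distance penalties when ALiBi is used), and c is f_attn_logit_softcapping.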
| 1293 |
+
|
| 1294 |
+
llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
|
| 1295 |
+
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
|
| 1296 |
+
|
| 1297 |
+
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
|
| 1298 |
+
inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
| 1299 |
+
//cb(inp_kq_mask, "KQ_mask", -1);
|
| 1300 |
+
ggml_set_input(inp->kq_mask);
|
| 1301 |
+
|
| 1302 |
+
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
|
| 1303 |
+
|
| 1304 |
+
return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
|
| 1305 |
+
}
|
| 1306 |
+
|
| 1307 |
+
ggml_tensor * llm_graph_context::build_attn(
|
| 1308 |
+
llm_graph_input_attn_no_cache * inp,
|
| 1309 |
+
ggml_cgraph * gf,
|
| 1310 |
+
ggml_tensor * wo,
|
| 1311 |
+
ggml_tensor * wo_b,
|
| 1312 |
+
ggml_tensor * q_cur,
|
| 1313 |
+
ggml_tensor * k_cur,
|
| 1314 |
+
ggml_tensor * v_cur,
|
| 1315 |
+
ggml_tensor * kq_b,
|
| 1316 |
+
ggml_tensor * v_mla,
|
| 1317 |
+
float kq_scale,
|
| 1318 |
+
int il) const {
|
| 1319 |
+
GGML_UNUSED(n_tokens);
|
| 1320 |
+
|
| 1321 |
+
// these nodes are added to the graph together so that they are not reordered
|
| 1322 |
+
// by doing so, the number of splits in the graph is reduced
|
| 1323 |
+
ggml_build_forward_expand(gf, q_cur);
|
| 1324 |
+
ggml_build_forward_expand(gf, k_cur);
|
| 1325 |
+
ggml_build_forward_expand(gf, v_cur);
|
| 1326 |
+
|
| 1327 |
+
const auto & kq_mask = inp->get_kq_mask();
|
| 1328 |
+
|
| 1329 |
+
ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
|
| 1330 |
+
//cb(q, "q", il);
|
| 1331 |
+
|
| 1332 |
+
ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
|
| 1333 |
+
//cb(k, "k", il);
|
| 1334 |
+
|
| 1335 |
+
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
|
| 1336 |
+
//cb(k, "v", il);
|
| 1337 |
+
|
| 1338 |
+
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
|
| 1339 |
+
|
| 1340 |
+
cb(cur, "kqv_out", il);
|
| 1341 |
+
|
| 1342 |
+
if (wo) {
|
| 1343 |
+
cur = build_lora_mm(wo, cur);
|
| 1344 |
+
}
|
| 1345 |
+
|
| 1346 |
+
if (wo_b) {
|
| 1347 |
+
//cb(cur, "kqv_wo", il);
|
| 1348 |
+
}
|
| 1349 |
+
|
| 1350 |
+
if (wo_b) {
|
| 1351 |
+
cur = ggml_add(ctx0, cur, wo_b);
|
| 1352 |
+
}
|
| 1353 |
+
|
| 1354 |
+
return cur;
|
| 1355 |
+
}
|
| 1356 |
+
|
| 1357 |
+
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
|
| 1358 |
+
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
| 1359 |
+
|
| 1360 |
+
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
|
| 1361 |
+
|
| 1362 |
+
const auto n_kv = kv_self->n;
|
| 1363 |
+
|
| 1364 |
+
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
| 1365 |
+
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
| 1366 |
+
ggml_set_input(inp->self_kq_mask);
|
| 1367 |
+
|
| 1368 |
+
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
| 1369 |
+
|
| 1370 |
+
if (hparams.n_swa_pattern > 1) {
|
| 1371 |
+
GGML_ASSERT(hparams.n_swa > 0);
|
| 1372 |
+
|
| 1373 |
+
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
| 1374 |
+
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
| 1375 |
+
ggml_set_input(inp->self_kq_mask_swa);
|
| 1376 |
+
|
| 1377 |
+
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
| 1378 |
+
}
|
| 1379 |
+
|
| 1380 |
+
return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
|
| 1381 |
+
}
|
| 1382 |
+
|
| 1383 |
+
ggml_tensor * llm_graph_context::build_attn(
|
| 1384 |
+
llm_graph_input_attn_kv_unified * inp,
|
| 1385 |
+
ggml_cgraph * gf,
|
| 1386 |
+
ggml_tensor * wo,
|
| 1387 |
+
ggml_tensor * wo_b,
|
| 1388 |
+
ggml_tensor * q_cur,
|
| 1389 |
+
ggml_tensor * k_cur,
|
| 1390 |
+
ggml_tensor * v_cur,
|
| 1391 |
+
ggml_tensor * kq_b,
|
| 1392 |
+
ggml_tensor * v_mla,
|
| 1393 |
+
float kq_scale,
|
| 1394 |
+
int il) const {
|
| 1395 |
+
// these nodes are added to the graph together so that they are not reordered
|
| 1396 |
+
// by doing so, the number of splits in the graph is reduced
|
| 1397 |
+
ggml_build_forward_expand(gf, q_cur);
|
| 1398 |
+
ggml_build_forward_expand(gf, k_cur);
|
| 1399 |
+
ggml_build_forward_expand(gf, v_cur);
|
| 1400 |
+
|
| 1401 |
+
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
| 1402 |
+
const auto & n_ctx = cparams.n_ctx;
|
| 1403 |
+
|
| 1404 |
+
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
| 1405 |
+
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
| 1406 |
+
|
| 1407 |
+
const auto n_tokens = q_cur->ne[2];
|
| 1408 |
+
|
| 1409 |
+
const bool v_trans = !cparams.flash_attn;
|
| 1410 |
+
|
| 1411 |
+
// store to KV cache
|
| 1412 |
+
{
|
| 1413 |
+
GGML_ASSERT(!kv_self->recurrent);
|
| 1414 |
+
|
| 1415 |
+
const auto kv_head = kv_self->head;
|
| 1416 |
+
|
| 1417 |
+
GGML_ASSERT(kv_self->size == n_ctx);
|
| 1418 |
+
|
| 1419 |
+
ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
|
| 1420 |
+
//cb(k_cache_view, "k_cache_view", il);
|
| 1421 |
+
|
| 1422 |
+
// note: storing RoPE-ed version of K in the KV cache
|
| 1423 |
+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
|
| 1424 |
+
|
| 1425 |
+
v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
|
| 1426 |
+
|
| 1427 |
+
ggml_tensor * v_cache_view = nullptr;
|
| 1428 |
+
|
| 1429 |
+
if (!v_trans) {
|
| 1430 |
+
v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
|
| 1431 |
+
} else {
|
| 1432 |
+
// note: the V cache is transposed when not using flash attention
|
| 1433 |
+
v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
|
| 1434 |
+
( n_ctx)*ggml_element_size(kv_self->v_l[il]),
|
| 1435 |
+
(kv_head)*ggml_element_size(kv_self->v_l[il]));
|
| 1436 |
+
|
| 1437 |
+
v_cur = ggml_transpose(ctx0, v_cur);
|
| 1438 |
+
}
|
| 1439 |
+
//cb(v_cache_view, "v_cache_view", il);
|
| 1440 |
+
|
| 1441 |
+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
|
| 1442 |
+
}
|
| 1443 |
+
|
| 1444 |
+
const bool is_swa = hparams.is_swa(il);
|
| 1445 |
+
|
| 1446 |
+
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
| 1447 |
+
|
| 1448 |
+
const auto n_kv = kv_self->n;
|
| 1449 |
+
|
| 1450 |
+
const int64_t n_head_kv = hparams.n_head_kv(il);
|
| 1451 |
+
|
| 1452 |
+
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
| 1453 |
+
const auto & n_embd_head_v = hparams.n_embd_head_v;
|
| 1454 |
+
|
| 1455 |
+
ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
|
| 1456 |
+
//cb(q, "q", il);
|
| 1457 |
+
|
| 1458 |
+
ggml_tensor * k =
|
| 1459 |
+
ggml_view_3d(ctx0, kv_self->k_l[il],
|
| 1460 |
+
n_embd_head_k, n_kv, n_head_kv,
|
| 1461 |
+
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
|
| 1462 |
+
ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
|
| 1463 |
+
0);
|
| 1464 |
+
//cb(k, "k", il);
|
| 1465 |
+
|
| 1466 |
+
ggml_tensor * v = !v_trans ?
|
| 1467 |
+
ggml_view_3d(ctx0, kv_self->v_l[il],
|
| 1468 |
+
n_embd_head_v, n_kv, n_head_kv,
|
| 1469 |
+
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
|
| 1470 |
+
ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
|
| 1471 |
+
0) :
|
| 1472 |
+
ggml_view_3d(ctx0, kv_self->v_l[il],
|
| 1473 |
+
n_kv, n_embd_head_v, n_head_kv,
|
| 1474 |
+
ggml_element_size(kv_self->v_l[il])*n_ctx,
|
| 1475 |
+
ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
|
| 1476 |
+
0);
|
| 1477 |
+
|
| 1478 |
+
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
|
| 1479 |
+
cb(cur, "kqv_out", il);
|
| 1480 |
+
|
| 1481 |
+
if (wo) {
|
| 1482 |
+
cur = build_lora_mm(wo, cur);
|
| 1483 |
+
}
|
| 1484 |
+
|
| 1485 |
+
if (wo_b) {
|
| 1486 |
+
//cb(cur, "kqv_wo", il);
|
| 1487 |
+
}
|
| 1488 |
+
|
| 1489 |
+
if (wo_b) {
|
| 1490 |
+
cur = ggml_add(ctx0, cur, wo_b);
|
| 1491 |
+
}
|
| 1492 |
+
|
| 1493 |
+
return cur;
|
| 1494 |
+
}
|
| 1495 |
+
|
| 1496 |
+
llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
|
| 1497 |
+
auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);
|
| 1498 |
+
|
| 1499 |
+
const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
| 1500 |
+
|
| 1501 |
+
inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
| 1502 |
+
ggml_set_input(inp->cross_kq_mask);
|
| 1503 |
+
|
| 1504 |
+
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
|
| 1505 |
+
|
| 1506 |
+
return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
|
| 1507 |
+
}
|
| 1508 |
+
|
| 1509 |
+
ggml_tensor * llm_graph_context::build_attn(
|
| 1510 |
+
llm_graph_input_attn_cross * inp,
|
| 1511 |
+
ggml_cgraph * gf,
|
| 1512 |
+
ggml_tensor * wo,
|
| 1513 |
+
ggml_tensor * wo_b,
|
| 1514 |
+
ggml_tensor * q_cur,
|
| 1515 |
+
ggml_tensor * k_cur,
|
| 1516 |
+
ggml_tensor * v_cur,
|
| 1517 |
+
ggml_tensor * kq_b,
|
| 1518 |
+
ggml_tensor * v_mla,
|
| 1519 |
+
float kq_scale,
|
| 1520 |
+
int il) const {
|
| 1521 |
+
// these nodes are added to the graph together so that they are not reordered
|
| 1522 |
+
// by doing so, the number of splits in the graph is reduced
|
| 1523 |
+
ggml_build_forward_expand(gf, q_cur);
|
| 1524 |
+
ggml_build_forward_expand(gf, k_cur);
|
| 1525 |
+
ggml_build_forward_expand(gf, v_cur);
|
| 1526 |
+
|
| 1527 |
+
const auto & kq_mask = inp->get_kq_mask_cross();
|
| 1528 |
+
|
| 1529 |
+
ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
|
| 1530 |
+
//cb(q, "q", il);
|
| 1531 |
+
|
| 1532 |
+
ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
|
| 1533 |
+
//cb(k, "k", il);
|
| 1534 |
+
|
| 1535 |
+
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
|
| 1536 |
+
//cb(v, "v", il);
|
| 1537 |
+
|
| 1538 |
+
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
|
| 1539 |
+
|
| 1540 |
+
cb(cur, "kqv_out", il);
|
| 1541 |
+
|
| 1542 |
+
if (wo) {
|
| 1543 |
+
cur = build_lora_mm(wo, cur);
|
| 1544 |
+
}
|
| 1545 |
+
|
| 1546 |
+
if (wo_b) {
|
| 1547 |
+
//cb(cur, "kqv_wo", il);
|
| 1548 |
+
}
|
| 1549 |
+
|
| 1550 |
+
if (wo_b) {
|
| 1551 |
+
cur = ggml_add(ctx0, cur, wo_b);
|
| 1552 |
+
}
|
| 1553 |
+
|
| 1554 |
+
return cur;
|
| 1555 |
+
}
|
| 1556 |
+
|
| 1557 |
+
ggml_tensor * llm_graph_context::build_copy_mask_state(
|
| 1558 |
+
ggml_cgraph * gf,
|
| 1559 |
+
ggml_tensor * s,
|
| 1560 |
+
ggml_tensor * state_copy,
|
| 1561 |
+
ggml_tensor * state_mask,
|
| 1562 |
+
int32_t n_state,
|
| 1563 |
+
int32_t n_seqs) const {
|
| 1564 |
+
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
| 1565 |
+
|
| 1566 |
+
const auto n_kv = kv_self->n;
|
| 1567 |
+
const auto kv_head = kv_self->head;
|
| 1568 |
+
|
| 1569 |
+
ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
|
| 1570 |
+
|
| 1571 |
+
// copy states
|
| 1572 |
+
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
|
| 1573 |
+
// this shrinks the tensor's ne[1] to n_kv
|
| 1574 |
+
states = ggml_get_rows(ctx0, states, state_copy);
|
| 1575 |
+
|
| 1576 |
+
// clear states of sequences which are starting at the beginning of this batch
|
| 1577 |
+
// FIXME: zero-out NANs?
|
| 1578 |
+
states = ggml_mul(ctx0, states, state_mask);
|
| 1579 |
+
|
| 1580 |
+
// copy states which won't be changed further (between n_seqs and n_kv)
|
| 1581 |
+
ggml_build_forward_expand(gf,
|
| 1582 |
+
ggml_cpy(ctx0,
|
| 1583 |
+
ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)),
|
| 1584 |
+
ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
|
| 1585 |
+
|
| 1586 |
+
// the part of the states that will be used and modified
|
| 1587 |
+
return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
|
| 1588 |
+
}
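Aside (not part of the synced diff): the body above is essentially "gather state rows through the copy indices, then zero the rows whose sequence restarts in this batch". The same semantics on plain arrays, as a self-contained sketch with invented sizes and values:

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_state = 2, n_kv = 3;
        // per-cell states, flattened as [cell][n_state]
        std::vector<float> s          = {1,1, 2,2, 3,3, 4,4};
        std::vector<int>   state_copy = {2, 0, 1}; // source cell for each of the n_kv slots
        std::vector<float> state_mask = {1, 0, 1}; // 0 clears a sequence that starts fresh

        std::vector<float> states(n_kv * n_state);
        for (int i = 0; i < n_kv; ++i) {
            for (int j = 0; j < n_state; ++j) {
                // mirrors ggml_get_rows(...) followed by ggml_mul(..., state_mask)
                states[i*n_state + j] = s[state_copy[i]*n_state + j] * state_mask[i];
            }
        }
        for (int i = 0; i < n_kv; ++i) {
            printf("slot %d: %.0f %.0f\n", i, states[i*n_state + 0], states[i*n_state + 1]);
        }
        return 0;
    }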
|
| 1589 |
+
|
| 1590 |
+
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
| 1591 |
+
ggml_cgraph * gf,
|
| 1592 |
+
ggml_tensor * state_copy,
|
| 1593 |
+
ggml_tensor * state_mask,
|
| 1594 |
+
const llama_ubatch & ubatch,
|
| 1595 |
+
int il) const {
|
| 1596 |
+
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
| 1597 |
+
|
| 1598 |
+
const auto token_shift_count = hparams.token_shift_count;
|
| 1599 |
+
|
| 1600 |
+
const int64_t n_seqs = ubatch.n_seqs;
|
| 1601 |
+
|
| 1602 |
+
ggml_tensor * token_shift_all = kv_self->k_l[il];
|
| 1603 |
+
|
| 1604 |
+
ggml_tensor * token_shift = build_copy_mask_state(
|
| 1605 |
+
gf, token_shift_all, state_copy, state_mask,
|
| 1606 |
+
hparams.n_embd_k_s(), n_seqs);
|
| 1607 |
+
|
| 1608 |
+
token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
|
| 1609 |
+
|
| 1610 |
+
return token_shift;
|
| 1611 |
+
}
|
| 1612 |
+
|
| 1613 |
+
ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
| 1614 |
+
ggml_tensor * token_shift,
|
| 1615 |
+
const llama_ubatch & ubatch,
|
| 1616 |
+
int il) const {
|
| 1617 |
+
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
| 1618 |
+
|
| 1619 |
+
const auto token_shift_count = hparams.token_shift_count;
|
| 1620 |
+
const auto n_embd = hparams.n_embd;
|
| 1621 |
+
|
| 1622 |
+
const int64_t n_seqs = ubatch.n_seqs;
|
| 1623 |
+
|
| 1624 |
+
const auto kv_head = kv_self->head;
|
| 1625 |
+
|
| 1626 |
+
return ggml_cpy(
|
| 1627 |
+
ctx0,
|
| 1628 |
+
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
|
| 1629 |
+
ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
|
| 1630 |
+
);
|
| 1631 |
+
}
|
| 1632 |
+
|
| 1633 |
+
void llm_graph_context::build_pooling(
|
| 1634 |
+
ggml_cgraph * gf,
|
| 1635 |
+
ggml_tensor * cls,
|
| 1636 |
+
ggml_tensor * cls_b,
|
| 1637 |
+
ggml_tensor * cls_out,
|
| 1638 |
+
ggml_tensor * cls_out_b) const {
|
| 1639 |
+
if (!cparams.embeddings) {
|
| 1640 |
+
return;
|
| 1641 |
+
}
|
| 1642 |
+
|
| 1643 |
+
ggml_tensor * inp = res->t_embd;
|
| 1644 |
+
|
| 1645 |
+
//// find result_norm tensor for input
|
| 1646 |
+
//for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
|
| 1647 |
+
// inp = ggml_graph_node(gf, i);
|
| 1648 |
+
// if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
|
| 1649 |
+
// break;
|
| 1650 |
+
// }
|
| 1651 |
+
|
| 1652 |
+
// inp = nullptr;
|
| 1653 |
+
//}
|
| 1654 |
+
|
| 1655 |
+
GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
|
| 1656 |
+
|
| 1657 |
+
ggml_tensor * cur;
|
| 1658 |
+
|
| 1659 |
+
switch (pooling_type) {
|
| 1660 |
+
case LLAMA_POOLING_TYPE_NONE:
|
| 1661 |
+
{
|
| 1662 |
+
cur = inp;
|
| 1663 |
+
} break;
|
| 1664 |
+
case LLAMA_POOLING_TYPE_MEAN:
|
| 1665 |
+
{
|
| 1666 |
+
ggml_tensor * inp_mean = build_inp_mean();
|
| 1667 |
+
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
|
| 1668 |
+
} break;
|
| 1669 |
+
case LLAMA_POOLING_TYPE_CLS:
|
| 1670 |
+
case LLAMA_POOLING_TYPE_LAST:
|
| 1671 |
+
{
|
| 1672 |
+
ggml_tensor * inp_cls = build_inp_cls();
|
| 1673 |
+
cur = ggml_get_rows(ctx0, inp, inp_cls);
|
| 1674 |
+
} break;
|
| 1675 |
+
case LLAMA_POOLING_TYPE_RANK:
|
| 1676 |
+
{
|
| 1677 |
+
ggml_tensor * inp_cls = build_inp_cls();
|
| 1678 |
+
inp = ggml_get_rows(ctx0, inp, inp_cls);
|
| 1679 |
+
|
| 1680 |
+
// classification head
|
| 1681 |
+
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
| 1682 |
+
GGML_ASSERT(cls != nullptr);
|
| 1683 |
+
GGML_ASSERT(cls_b != nullptr);
|
| 1684 |
+
|
| 1685 |
+
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
|
| 1686 |
+
cur = ggml_tanh(ctx0, cur);
|
| 1687 |
+
|
| 1688 |
+
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
| 1689 |
+
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
|
| 1690 |
+
if (cls_out) {
|
| 1691 |
+
GGML_ASSERT(cls_out_b != nullptr);
|
| 1692 |
+
|
| 1693 |
+
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
|
| 1694 |
+
}
|
| 1695 |
+
} break;
|
| 1696 |
+
default:
|
| 1697 |
+
{
|
| 1698 |
+
GGML_ABORT("unknown pooling type");
|
| 1699 |
+
}
|
| 1700 |
+
}
|
| 1701 |
+
|
| 1702 |
+
cb(cur, "result_embd_pooled", -1);
|
| 1703 |
+
res->t_embd_pooled = cur;
|
| 1704 |
+
|
| 1705 |
+
ggml_build_forward_expand(gf, cur);
|
| 1706 |
+
}
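Aside (not part of the synced diff): the MEAN pooling branch above reduces to a single matrix product, where the weight matrix holds 1/len for the tokens of each sequence and 0 elsewhere. A self-contained sketch of that idea (shapes and data invented):

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_embd = 2, n_tokens = 4, n_seqs = 2;
        // token embeddings, stored as inp[e + t*n_embd]
        std::vector<float> inp = {1,10, 3,30, 5,50, 7,70};
        // tokens 0,1 belong to sequence 0; tokens 2,3 to sequence 1
        std::vector<float> w   = {0.5f,0.0f, 0.5f,0.0f, 0.0f,0.5f, 0.0f,0.5f}; // w[s + t*n_seqs]

        std::vector<float> out(n_embd * n_seqs, 0.0f);
        for (int s = 0; s < n_seqs; ++s)
            for (int e = 0; e < n_embd; ++e)
                for (int t = 0; t < n_tokens; ++t)
                    out[e + s*n_embd] += inp[e + t*n_embd] * w[s + t*n_seqs];

        // expected: seq0 = (2, 20), seq1 = (6, 60)
        printf("seq0: %.0f %.0f  seq1: %.0f %.0f\n", out[0], out[1], out[2], out[3]);
        return 0;
    }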
|
examples/talk-llama/llama-graph.h
ADDED
|
@@ -0,0 +1,596 @@
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "llama-arch.h"
|
| 4 |
+
#include "llama-hparams.h"
|
| 5 |
+
#include "llama-adapter.h"
|
| 6 |
+
|
| 7 |
+
#include <cstdint>
|
| 8 |
+
#include <vector>
|
| 9 |
+
#include <memory>
|
| 10 |
+
#include <set>
|
| 11 |
+
#include <functional>
|
| 12 |
+
|
| 13 |
+
struct ggml_cgraph;
|
| 14 |
+
struct ggml_context;
|
| 15 |
+
struct ggml_tensor;
|
| 16 |
+
|
| 17 |
+
struct llama_ubatch;
|
| 18 |
+
struct llama_cparams;
|
| 19 |
+
|
| 20 |
+
class llama_memory_i;
|
| 21 |
+
class llama_kv_cache_unified;
|
| 22 |
+
|
| 23 |
+
// certain models (typically multi-modal) can produce different types of graphs
|
| 24 |
+
enum llm_graph_type {
|
| 25 |
+
LLM_GRAPH_TYPE_DEFAULT,
|
| 26 |
+
LLM_GRAPH_TYPE_ENCODER,
|
| 27 |
+
LLM_GRAPH_TYPE_DECODER,
|
| 28 |
+
};
|
| 29 |
+
|
| 30 |
+
enum llm_ffn_op_type {
|
| 31 |
+
LLM_FFN_SILU,
|
| 32 |
+
LLM_FFN_GELU,
|
| 33 |
+
LLM_FFN_RELU,
|
| 34 |
+
LLM_FFN_RELU_SQR,
|
| 35 |
+
LLM_FFN_SWIGLU,
|
| 36 |
+
};
|
| 37 |
+
|
| 38 |
+
enum llm_ffn_gate_type {
|
| 39 |
+
LLM_FFN_SEQ,
|
| 40 |
+
LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
|
| 41 |
+
};
|
| 42 |
+
|
| 43 |
+
enum llm_norm_type {
|
| 44 |
+
LLM_NORM,
|
| 45 |
+
LLM_NORM_RMS,
|
| 46 |
+
LLM_NORM_GROUP,
|
| 47 |
+
};
|
| 48 |
+
|
| 49 |
+
// TODO: tmp - need something better to pass the data from the encoder to the decoder
|
| 50 |
+
struct llama_cross {
|
| 51 |
+
// the output embeddings from the encoder as a ggml tensor
|
| 52 |
+
// TODO: this needs more work to be correct, for now copy the embeddings data to host memory
|
| 53 |
+
// ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
|
| 54 |
+
//ggml_tensor * t_embd = nullptr;
|
| 55 |
+
|
| 56 |
+
int64_t n_embd = 0;
|
| 57 |
+
int64_t n_enc = 0;
|
| 58 |
+
|
| 59 |
+
// embeddings data copied to host memory (tmp)
|
| 60 |
+
std::vector<float> v_embd;
|
| 61 |
+
|
| 62 |
+
// needed to construct the cross-attention mask in the decoder
|
| 63 |
+
std::vector<std::set<llama_seq_id>> seq_ids_enc;
|
| 64 |
+
};
|
| 65 |
+
|
| 66 |
+
//
|
| 67 |
+
// llm_graph_input
|
| 68 |
+
//
|
| 69 |
+
|
| 70 |
+
class llm_graph_input_i {
|
| 71 |
+
public:
|
| 72 |
+
virtual ~llm_graph_input_i() = default;
|
| 73 |
+
|
| 74 |
+
virtual void set_input(const llama_ubatch * ubatch) = 0;
|
| 75 |
+
};
|
| 76 |
+
|
| 77 |
+
using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class llm_graph_input_embd : public llm_graph_input_i {
|
| 81 |
+
public:
|
| 82 |
+
llm_graph_input_embd() = default;
|
| 83 |
+
virtual ~llm_graph_input_embd() = default;
|
| 84 |
+
|
| 85 |
+
void set_input(const llama_ubatch * ubatch) override;
|
| 86 |
+
|
| 87 |
+
ggml_tensor * tokens = nullptr; // I32 [n_batch]
|
| 88 |
+
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
|
| 89 |
+
};
|
| 90 |
+
|
| 91 |
+
class llm_graph_input_pos : public llm_graph_input_i {
|
| 92 |
+
public:
|
| 93 |
+
llm_graph_input_pos(int64_t n_pos_per_token) : n_pos_per_token(n_pos_per_token) {}
|
| 94 |
+
virtual ~llm_graph_input_pos() = default;
|
| 95 |
+
|
| 96 |
+
void set_input(const llama_ubatch * ubatch) override;
|
| 97 |
+
|
| 98 |
+
ggml_tensor * pos = nullptr; // I32 [n_batch]
|
| 99 |
+
|
| 100 |
+
const int64_t n_pos_per_token = 1;
|
| 101 |
+
};
|
| 102 |
+
|
| 103 |
+
// temperature tuning, used by llama4
|
| 104 |
+
class llm_graph_input_attn_temp : public llm_graph_input_i {
|
| 105 |
+
public:
|
| 106 |
+
llm_graph_input_attn_temp(int64_t n_pos_per_token, uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
|
| 107 |
+
: n_pos_per_token(n_pos_per_token), n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
|
| 108 |
+
virtual ~llm_graph_input_attn_temp() = default;
|
| 109 |
+
|
| 110 |
+
void set_input(const llama_ubatch * ubatch) override;
|
| 111 |
+
|
| 112 |
+
ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
|
| 113 |
+
|
| 114 |
+
const int64_t n_pos_per_token = 1;
|
| 115 |
+
|
| 116 |
+
const uint32_t n_attn_temp_floor_scale;
|
| 117 |
+
const float f_attn_temp_scale;
|
| 118 |
+
};
|
| 119 |
+
|
| 120 |
+
class llm_graph_input_pos_bucket : public llm_graph_input_i {
|
| 121 |
+
public:
|
| 122 |
+
llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
|
| 123 |
+
virtual ~llm_graph_input_pos_bucket() = default;
|
| 124 |
+
|
| 125 |
+
void set_input(const llama_ubatch * ubatch) override;
|
| 126 |
+
|
| 127 |
+
ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
|
| 128 |
+
|
| 129 |
+
const llama_hparams & hparams;
|
| 130 |
+
};
|
| 131 |
+
|
| 132 |
+
class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
|
| 133 |
+
public:
|
| 134 |
+
llm_graph_input_pos_bucket_kv(
|
| 135 |
+
const llama_hparams & hparams,
|
| 136 |
+
const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
|
| 137 |
+
virtual ~llm_graph_input_pos_bucket_kv() = default;
|
| 138 |
+
|
| 139 |
+
void set_input(const llama_ubatch * ubatch) override;
|
| 140 |
+
|
| 141 |
+
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
|
| 142 |
+
|
| 143 |
+
const llama_hparams & hparams;
|
| 144 |
+
const llama_kv_cache_unified * kv_self;
|
| 145 |
+
};
|
| 146 |
+
|
| 147 |
+
class llm_graph_input_out_ids : public llm_graph_input_i {
|
| 148 |
+
public:
|
| 149 |
+
llm_graph_input_out_ids(
|
| 150 |
+
const llama_hparams & hparams,
|
| 151 |
+
const llama_cparams & cparams,
|
| 152 |
+
int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
|
| 153 |
+
virtual ~llm_graph_input_out_ids() = default;
|
| 154 |
+
|
| 155 |
+
void set_input(const llama_ubatch * ubatch) override;
|
| 156 |
+
|
| 157 |
+
ggml_tensor * out_ids; // I32 [n_outputs]
|
| 158 |
+
|
| 159 |
+
const llama_hparams & hparams;
|
| 160 |
+
const llama_cparams & cparams;
|
| 161 |
+
|
| 162 |
+
const int32_t n_outputs;
|
| 163 |
+
};
|
| 164 |
+
|
| 165 |
+
class llm_graph_input_mean : public llm_graph_input_i {
|
| 166 |
+
public:
|
| 167 |
+
llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
|
| 168 |
+
virtual ~llm_graph_input_mean() = default;
|
| 169 |
+
|
| 170 |
+
void set_input(const llama_ubatch * ubatch) override;
|
| 171 |
+
|
| 172 |
+
ggml_tensor * mean; // F32 [n_batch, n_batch]
|
| 173 |
+
|
| 174 |
+
const llama_cparams & cparams;
|
| 175 |
+
};
|
| 176 |
+
|
| 177 |
+
class llm_graph_input_cls : public llm_graph_input_i {
|
| 178 |
+
public:
|
| 179 |
+
llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
|
| 180 |
+
virtual ~llm_graph_input_cls() = default;
|
| 181 |
+
|
| 182 |
+
void set_input(const llama_ubatch * ubatch) override;
|
| 183 |
+
|
| 184 |
+
ggml_tensor * cls; // I32 [n_batch]
|
| 185 |
+
|
| 186 |
+
const llama_cparams & cparams;
|
| 187 |
+
};
|
| 188 |
+
|
| 189 |
+
class llm_graph_input_s_copy : public llm_graph_input_i {
|
| 190 |
+
public:
|
| 191 |
+
llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
|
| 192 |
+
virtual ~llm_graph_input_s_copy() = default;
|
| 193 |
+
|
| 194 |
+
void set_input(const llama_ubatch * ubatch) override;
|
| 195 |
+
|
| 196 |
+
ggml_tensor * s_copy; // I32 [kv_size]
|
| 197 |
+
|
| 198 |
+
const llama_kv_cache_unified * kv_self;
|
| 199 |
+
};
|
| 200 |
+
|
| 201 |
+
class llm_graph_input_s_mask : public llm_graph_input_i {
|
| 202 |
+
public:
|
| 203 |
+
llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
|
| 204 |
+
virtual ~llm_graph_input_s_mask() = default;
|
| 205 |
+
|
| 206 |
+
void set_input(const llama_ubatch * ubatch) override;
|
| 207 |
+
|
| 208 |
+
ggml_tensor * s_mask; // F32 [1, n_kv]
|
| 209 |
+
|
| 210 |
+
const llama_kv_cache_unified * kv_self;
|
| 211 |
+
};
|
| 212 |
+
|
| 213 |
+
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
| 214 |
+
public:
|
| 215 |
+
llm_graph_input_cross_embd(
|
| 216 |
+
const llama_cross * cross) : cross(cross) {}
|
| 217 |
+
virtual ~llm_graph_input_cross_embd() = default;
|
| 218 |
+
|
| 219 |
+
void set_input(const llama_ubatch * ubatch) override;
|
| 220 |
+
|
| 221 |
+
ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
|
| 222 |
+
|
| 223 |
+
const llama_cross * cross;
|
| 224 |
+
};
|
| 225 |
+
|
| 226 |
+
class llm_graph_input_attn_no_cache : public llm_graph_input_i {
|
| 227 |
+
public:
|
| 228 |
+
llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
|
| 229 |
+
hparams(hparams),
|
| 230 |
+
cparams(cparams) {
|
| 231 |
+
}
|
| 232 |
+
~llm_graph_input_attn_no_cache() = default;
|
| 233 |
+
|
| 234 |
+
void set_input(const llama_ubatch * ubatch) override;
|
| 235 |
+
|
| 236 |
+
ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
|
| 237 |
+
|
| 238 |
+
ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
|
| 239 |
+
ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
|
| 240 |
+
|
| 241 |
+
const llama_hparams & hparams;
|
| 242 |
+
const llama_cparams & cparams;
|
| 243 |
+
};
|
| 244 |
+
|
| 245 |
+
class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
|
| 246 |
+
public:
|
| 247 |
+
llm_graph_input_attn_kv_unified(
|
| 248 |
+
const llama_hparams & hparams,
|
| 249 |
+
const llama_cparams & cparams,
|
| 250 |
+
const llama_kv_cache_unified * kv_self) :
|
| 251 |
+
hparams(hparams),
|
| 252 |
+
cparams(cparams),
|
| 253 |
+
kv_self(kv_self) {
|
| 254 |
+
}
|
| 255 |
+
~llm_graph_input_attn_kv_unified() = default;
|
| 256 |
+
|
| 257 |
+
void set_input(const llama_ubatch * ubatch) override;
|
| 258 |
+
|
| 259 |
+
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
| 260 |
+
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
|
| 261 |
+
|
| 262 |
+
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
|
| 263 |
+
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
|
| 264 |
+
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
|
| 265 |
+
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
|
| 266 |
+
|
| 267 |
+
const llama_hparams & hparams;
|
| 268 |
+
const llama_cparams & cparams;
|
| 269 |
+
|
| 270 |
+
const llama_kv_cache_unified * kv_self;
|
| 271 |
+
};
|
| 272 |
+
|
| 273 |
+
class llm_graph_input_attn_cross : public llm_graph_input_i {
|
| 274 |
+
public:
|
| 275 |
+
llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
|
| 276 |
+
~llm_graph_input_attn_cross() = default;
|
| 277 |
+
|
| 278 |
+
void set_input(const llama_ubatch * ubatch) override;
|
| 279 |
+
|
| 280 |
+
ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
|
| 281 |
+
|
| 282 |
+
ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
|
| 283 |
+
ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
|
| 284 |
+
|
| 285 |
+
const llama_cross * cross = nullptr;
|
| 286 |
+
};
|
| 287 |
+
|
| 288 |
+
//
|
| 289 |
+
// llm_graph_result
|
| 290 |
+
//
|
| 291 |
+
|
| 292 |
+
// these objects deliver the result from the graph build process back to the llama_context
|
| 293 |
+
// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
|
| 294 |
+
// specific data, by calling the set_inputs() method
|
| 295 |
+
// along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
|
| 296 |
+
// these are used by the llama_context to extract the relevant data, based on the compute parameters
|
| 297 |
+
|
| 298 |
+
class llm_graph_result_i {
|
| 299 |
+
public:
|
| 300 |
+
virtual ~llm_graph_result_i() = default;
|
| 301 |
+
|
| 302 |
+
virtual ggml_tensor * get_logits() = 0;
|
| 303 |
+
virtual ggml_tensor * get_embd() = 0;
|
| 304 |
+
virtual ggml_tensor * get_embd_pooled() = 0;
|
| 305 |
+
|
| 306 |
+
virtual void set_inputs(const llama_ubatch * ubatch) = 0;
|
| 307 |
+
};
|
| 308 |
+
|
| 309 |
+
using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
class llm_graph_result : public llm_graph_result_i {
|
| 313 |
+
public:
|
| 314 |
+
virtual ~llm_graph_result() = default;
|
| 315 |
+
|
| 316 |
+
ggml_tensor * get_logits() override { return t_logits; }
|
| 317 |
+
ggml_tensor * get_embd() override { return t_embd; }
|
| 318 |
+
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
|
| 319 |
+
|
| 320 |
+
void set_inputs(const llama_ubatch * ubatch) override {
|
| 321 |
+
for (auto & input : inputs) {
|
| 322 |
+
input->set_input(ubatch);
|
| 323 |
+
}
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
llm_graph_input_i * add_input(llm_graph_input_ptr input) {
|
| 327 |
+
inputs.emplace_back(std::move(input));
|
| 328 |
+
return inputs.back().get();
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
// important graph nodes
|
| 332 |
+
ggml_tensor * t_logits = nullptr;
|
| 333 |
+
ggml_tensor * t_embd = nullptr;
|
| 334 |
+
ggml_tensor * t_embd_pooled = nullptr;
|
| 335 |
+
|
| 336 |
+
std::vector<llm_graph_input_ptr> inputs;
|
| 337 |
+
};
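Aside (not part of the synced diff): the class above collects polymorphic inputs while the graph is being built and later populates them all in one pass. A minimal standalone stand-in for that pattern; every name below is hypothetical and only mirrors the structure of llm_graph_result:

    #include <cstdio>
    #include <memory>
    #include <vector>

    struct batch { int n_tokens = 0; };

    struct input_i {
        virtual ~input_i() = default;
        virtual void set_input(const batch & b) = 0;
    };

    struct input_pos : input_i {
        void set_input(const batch & b) override { printf("fill %d positions\n", b.n_tokens); }
    };

    struct result {
        std::vector<std::unique_ptr<input_i>> inputs;
        input_i * add_input(std::unique_ptr<input_i> in) {
            inputs.emplace_back(std::move(in));
            return inputs.back().get();
        }
        void set_inputs(const batch & b) {
            for (auto & in : inputs) in->set_input(b);
        }
    };

    int main() {
        result res;
        res.add_input(std::make_unique<input_pos>()); // registered during graph build
        batch b; b.n_tokens = 8;
        res.set_inputs(b);                            // populated right before compute
        return 0;
    }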
|
| 338 |
+
|
| 339 |
+
//
|
| 340 |
+
// llm_graph_context
|
| 341 |
+
//
|
| 342 |
+
|
| 343 |
+
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
|
| 344 |
+
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
|
| 345 |
+
|
| 346 |
+
struct llm_graph_params {
|
| 347 |
+
ggml_context * ctx;
|
| 348 |
+
|
| 349 |
+
const llm_arch arch;
|
| 350 |
+
|
| 351 |
+
const llama_hparams & hparams;
|
| 352 |
+
const llama_cparams & cparams;
|
| 353 |
+
const llama_ubatch & ubatch;
|
| 354 |
+
|
| 355 |
+
ggml_backend_sched * sched;
|
| 356 |
+
ggml_backend * backend_cpu;
|
| 357 |
+
|
| 358 |
+
const llama_adapter_cvec * cvec;
|
| 359 |
+
const llama_adapter_loras * loras;
|
| 360 |
+
const llama_memory_i * memory;
|
| 361 |
+
const llama_cross * cross;
|
| 362 |
+
|
| 363 |
+
int32_t n_outputs;
|
| 364 |
+
|
| 365 |
+
const llm_graph_cb & cb;
|
| 366 |
+
};
|
| 367 |
+
|
| 368 |
+
struct llm_graph_context {
|
| 369 |
+
const llm_arch arch;
|
| 370 |
+
|
| 371 |
+
const llama_hparams & hparams;
|
| 372 |
+
const llama_cparams & cparams;
|
| 373 |
+
const llama_ubatch & ubatch;
|
| 374 |
+
|
| 375 |
+
const int64_t n_embd;
|
| 376 |
+
const int64_t n_layer;
|
| 377 |
+
const int64_t n_rot;
|
| 378 |
+
const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
|
| 379 |
+
const int64_t n_ctx_per_seq;
|
| 380 |
+
const int64_t n_head;
|
| 381 |
+
const int64_t n_head_kv;
|
| 382 |
+
const int64_t n_embd_head_k;
|
| 383 |
+
const int64_t n_embd_k_gqa;
|
| 384 |
+
const int64_t n_embd_head_v;
|
| 385 |
+
const int64_t n_embd_v_gqa;
|
| 386 |
+
const int64_t n_expert;
|
| 387 |
+
const int64_t n_expert_used;
|
| 388 |
+
|
| 389 |
+
const float freq_base;
|
| 390 |
+
const float freq_scale;
|
| 391 |
+
const float ext_factor;
|
| 392 |
+
const float attn_factor;
|
| 393 |
+
const float beta_fast;
|
| 394 |
+
const float beta_slow;
|
| 395 |
+
const float norm_eps;
|
| 396 |
+
const float norm_rms_eps;
|
| 397 |
+
|
| 398 |
+
const int32_t n_tokens;
|
| 399 |
+
const int32_t n_outputs;
|
| 400 |
+
const int32_t n_ctx_orig; // yarn
|
| 401 |
+
|
| 402 |
+
const enum llama_pooling_type pooling_type;
|
| 403 |
+
const enum llama_rope_type rope_type;
|
| 404 |
+
|
| 405 |
+
ggml_context * ctx0 = nullptr;
|
| 406 |
+
|
| 407 |
+
ggml_backend_sched * sched;
|
| 408 |
+
|
| 409 |
+
ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
|
| 410 |
+
|
| 411 |
+
const llama_adapter_cvec * cvec;
|
| 412 |
+
const llama_adapter_loras * loras;
|
| 413 |
+
const llama_memory_i * memory;
|
| 414 |
+
const llama_cross * cross;
|
| 415 |
+
|
| 416 |
+
const llm_graph_cb & cb_func;
|
| 417 |
+
|
| 418 |
+
std::unique_ptr<llm_graph_result> res;
|
| 419 |
+
|
| 420 |
+
llm_graph_context(const llm_graph_params & params);
|
| 421 |
+
|
| 422 |
+
int64_t n_pos_per_token() const;
|
| 423 |
+
|
| 424 |
+
void cb(ggml_tensor * cur, const char * name, int il) const;
|
| 425 |
+
|
| 426 |
+
//
|
| 427 |
+
// common
|
| 428 |
+
//
|
| 429 |
+
|
| 430 |
+
ggml_tensor * build_cvec(
|
| 431 |
+
ggml_tensor * cur,
|
| 432 |
+
int il) const;
|
| 433 |
+
|
| 434 |
+
// do mat_mul, while optionally applying lora
|
| 435 |
+
ggml_tensor * build_lora_mm(
|
| 436 |
+
ggml_tensor * w,
|
| 437 |
+
ggml_tensor * cur) const;
|
| 438 |
+
|
| 439 |
+
// do mat_mul_id, while optionally applying lora
|
| 440 |
+
ggml_tensor * build_lora_mm_id(
|
| 441 |
+
ggml_tensor * w, // ggml_tensor * as
|
| 442 |
+
ggml_tensor * cur, // ggml_tensor * b
|
| 443 |
+
ggml_tensor * ids) const;
|
| 444 |
+
|
| 445 |
+
ggml_tensor * build_norm(
|
| 446 |
+
ggml_tensor * cur,
|
| 447 |
+
ggml_tensor * mw,
|
| 448 |
+
ggml_tensor * mb,
|
| 449 |
+
llm_norm_type type,
|
| 450 |
+
int il) const;
|
| 451 |
+
|
| 452 |
+
ggml_tensor * build_ffn(
|
| 453 |
+
ggml_tensor * cur,
|
| 454 |
+
ggml_tensor * up,
|
| 455 |
+
ggml_tensor * up_b,
|
| 456 |
+
ggml_tensor * up_s,
|
| 457 |
+
ggml_tensor * gate,
|
| 458 |
+
ggml_tensor * gate_b,
|
| 459 |
+
ggml_tensor * gate_s,
|
| 460 |
+
ggml_tensor * down,
|
| 461 |
+
ggml_tensor * down_b,
|
| 462 |
+
ggml_tensor * down_s,
|
| 463 |
+
ggml_tensor * act_scales,
|
| 464 |
+
llm_ffn_op_type type_op,
|
| 465 |
+
llm_ffn_gate_type type_gate,
|
| 466 |
+
int il) const;
|
| 467 |
+
|
| 468 |
+
ggml_tensor * build_moe_ffn(
|
| 469 |
+
ggml_tensor * cur,
|
| 470 |
+
ggml_tensor * gate_inp,
|
| 471 |
+
ggml_tensor * up_exps,
|
| 472 |
+
ggml_tensor * gate_exps,
|
| 473 |
+
ggml_tensor * down_exps,
|
| 474 |
+
ggml_tensor * exp_probs_b,
|
| 475 |
+
int64_t n_expert,
|
| 476 |
+
int64_t n_expert_used,
|
| 477 |
+
llm_ffn_op_type type_op,
|
| 478 |
+
bool norm_w,
|
| 479 |
+
bool scale_w,
|
| 480 |
+
float w_scale,
|
| 481 |
+
llama_expert_gating_func_type gating_op,
|
| 482 |
+
int il) const;
|
| 483 |
+
|
| 484 |
+
//
|
| 485 |
+
// inputs
|
| 486 |
+
//
|
| 487 |
+
|
| 488 |
+
ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
|
| 489 |
+
ggml_tensor * build_inp_pos() const;
|
| 490 |
+
ggml_tensor * build_inp_attn_scale() const;
|
| 491 |
+
ggml_tensor * build_inp_out_ids() const;
|
| 492 |
+
ggml_tensor * build_inp_mean() const;
|
| 493 |
+
ggml_tensor * build_inp_cls() const;
|
| 494 |
+
ggml_tensor * build_inp_s_copy() const;
|
| 495 |
+
ggml_tensor * build_inp_s_mask() const;
|
| 496 |
+
|
| 497 |
+
ggml_tensor * build_inp_cross_embd() const;
|
| 498 |
+
ggml_tensor * build_inp_pos_bucket_enc() const;
|
| 499 |
+
ggml_tensor * build_inp_pos_bucket_dec() const;
|
| 500 |
+
ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
|
| 501 |
+
|
| 502 |
+
//
|
| 503 |
+
// attention
|
| 504 |
+
//
|
| 505 |
+
|
| 506 |
+
ggml_tensor * build_attn_mha(
|
| 507 |
+
ggml_cgraph * gf,
|
| 508 |
+
ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
|
| 509 |
+
ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
|
| 510 |
+
ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
|
| 511 |
+
ggml_tensor * kq_b,
|
| 512 |
+
ggml_tensor * kq_mask,
|
| 513 |
+
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
| 514 |
+
bool v_trans,
|
| 515 |
+
float kq_scale) const;
|
| 516 |
+
|
| 517 |
+
llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
|
| 518 |
+
|
| 519 |
+
ggml_tensor * build_attn(
|
| 520 |
+
llm_graph_input_attn_no_cache * inp,
|
| 521 |
+
ggml_cgraph * gf,
|
| 522 |
+
ggml_tensor * wo,
|
| 523 |
+
ggml_tensor * wo_b,
|
| 524 |
+
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
| 525 |
+
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
| 526 |
+
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
| 527 |
+
ggml_tensor * kq_b,
|
| 528 |
+
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
| 529 |
+
float kq_scale,
|
| 530 |
+
int il) const;
|
| 531 |
+
|
| 532 |
+
llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
|
| 533 |
+
|
| 534 |
+
ggml_tensor * build_attn(
|
| 535 |
+
llm_graph_input_attn_kv_unified * inp,
|
| 536 |
+
ggml_cgraph * gf,
|
| 537 |
+
ggml_tensor * wo,
|
| 538 |
+
ggml_tensor * wo_b,
|
| 539 |
+
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
| 540 |
+
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
| 541 |
+
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
| 542 |
+
ggml_tensor * kq_b,
|
| 543 |
+
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
| 544 |
+
float kq_scale,
|
| 545 |
+
int il) const;
|
| 546 |
+
|
| 547 |
+
llm_graph_input_attn_cross * build_attn_inp_cross() const;
|
| 548 |
+
|
| 549 |
+
ggml_tensor * build_attn(
|
| 550 |
+
llm_graph_input_attn_cross * inp,
|
| 551 |
+
ggml_cgraph * gf,
|
| 552 |
+
ggml_tensor * wo,
|
| 553 |
+
ggml_tensor * wo_b,
|
| 554 |
+
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
| 555 |
+
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
| 556 |
+
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
| 557 |
+
ggml_tensor * kq_b,
|
| 558 |
+
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
| 559 |
+
float kq_scale,
|
| 560 |
+
int il) const;
|
| 561 |
+
|
| 562 |
+
//
|
| 563 |
+
// recurrent
|
| 564 |
+
//
|
| 565 |
+
|
| 566 |
+
ggml_tensor * build_copy_mask_state(
|
| 567 |
+
ggml_cgraph * gf,
|
| 568 |
+
ggml_tensor * s,
|
| 569 |
+
ggml_tensor * state_copy,
|
| 570 |
+
ggml_tensor * state_mask,
|
| 571 |
+
int32_t n_state,
|
| 572 |
+
int32_t n_seqs) const;
|
| 573 |
+
|
| 574 |
+
ggml_tensor * build_rwkv_token_shift_load(
|
| 575 |
+
ggml_cgraph * gf,
|
| 576 |
+
ggml_tensor * state_copy,
|
| 577 |
+
ggml_tensor * state_mask,
|
| 578 |
+
const llama_ubatch & ubatch,
|
| 579 |
+
int il) const;
|
| 580 |
+
|
| 581 |
+
ggml_tensor * build_rwkv_token_shift_store(
|
| 582 |
+
ggml_tensor * token_shift,
|
| 583 |
+
const llama_ubatch & ubatch,
|
| 584 |
+
int il) const;
|
| 585 |
+
|
| 586 |
+
//
|
| 587 |
+
// pooling
|
| 588 |
+
//
|
| 589 |
+
|
| 590 |
+
void build_pooling(
|
| 591 |
+
ggml_cgraph * gf,
|
| 592 |
+
ggml_tensor * cls,
|
| 593 |
+
ggml_tensor * cls_b,
|
| 594 |
+
ggml_tensor * cls_out,
|
| 595 |
+
ggml_tensor * cls_out_b) const;
|
| 596 |
+
};
|
examples/talk-llama/llama-hparams.cpp
CHANGED
|
@@ -69,3 +69,11 @@ uint32_t llama_hparams::n_embd_v_s() const {
|
|
| 69 |
// corresponds to Mamba's ssm_states size
|
| 70 |
return ssm_d_state * ssm_d_inner;
|
| 71 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
// corresponds to Mamba's ssm_states size
|
| 70 |
return ssm_d_state * ssm_d_inner;
|
| 71 |
}
|
| 72 |
+
|
| 73 |
+
bool llama_hparams::is_swa(uint32_t il) const {
|
| 74 |
+
if (il < n_layer) {
|
| 75 |
+
return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
GGML_ABORT("fatal error");
|
| 79 |
+
}
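Aside (not part of the synced diff): under this rule every n_swa_pattern-th layer keeps full attention and the remaining layers use the sliding window. A standalone check with invented values (8 layers, pattern 4):

    #include <cstdio>
    #include <cstdint>

    // same predicate as above, reproduced for a quick standalone check
    static bool is_swa(uint32_t il, uint32_t n_swa, uint32_t n_swa_pattern) {
        return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
    }

    int main() {
        const uint32_t n_layer = 8, n_swa = 512, pattern = 4; // invented values
        for (uint32_t il = 0; il < n_layer; ++il) {
            printf("layer %u: %s\n", il, is_swa(il, n_swa, pattern) ? "swa" : "full");
        }
        // with pattern = 4, layers 3 and 7 keep full attention
        return 0;
    }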
|
examples/talk-llama/llama-hparams.h
CHANGED
|
@@ -36,12 +36,17 @@ struct llama_hparams {
|
|
| 36 |
uint32_t n_layer;
|
| 37 |
uint32_t n_rot;
|
| 38 |
uint32_t n_swa = 0; // sliding window attention (SWA)
|
|
|
|
| 39 |
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
|
| 40 |
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
|
| 41 |
uint32_t n_expert = 0;
|
| 42 |
uint32_t n_expert_used = 0;
|
| 43 |
uint32_t n_rel_attn_bkts = 0;
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
// for WavTokenizer
|
| 46 |
struct llama_hparams_posnet posnet;
|
| 47 |
struct llama_hparams_convnext convnext;
|
|
@@ -75,10 +80,16 @@ struct llama_hparams {
|
|
| 75 |
uint32_t time_decay_extra_dim = 0;
|
| 76 |
uint32_t wkv_head_size = 0;
|
| 77 |
uint32_t token_shift_count = 2;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
float rope_attn_factor = 1.0f;
|
| 80 |
float rope_freq_base_train;
|
|
|
|
| 81 |
float rope_freq_scale_train;
|
|
|
|
| 82 |
uint32_t n_ctx_orig_yarn;
|
| 83 |
float rope_yarn_log_mul;
|
| 84 |
|
|
@@ -105,6 +116,14 @@ struct llama_hparams {
|
|
| 105 |
bool use_alibi = false;
|
| 106 |
bool attn_soft_cap = false;
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
|
| 109 |
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
|
| 110 |
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
|
|
@@ -133,6 +152,8 @@ struct llama_hparams {
|
|
| 133 |
|
| 134 |
// dimension of the recurrent state embeddings
|
| 135 |
uint32_t n_embd_v_s() const;
|
|
|
|
|
|
|
| 136 |
};
|
| 137 |
|
| 138 |
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
|
|
|
|
| 36 |
uint32_t n_layer;
|
| 37 |
uint32_t n_rot;
|
| 38 |
uint32_t n_swa = 0; // sliding window attention (SWA)
|
| 39 |
+
uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
|
| 40 |
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
|
| 41 |
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
|
| 42 |
uint32_t n_expert = 0;
|
| 43 |
uint32_t n_expert_used = 0;
|
| 44 |
uint32_t n_rel_attn_bkts = 0;
|
| 45 |
|
| 46 |
+
// note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
|
| 47 |
+
uint32_t n_embd_head_k_mla = 0;
|
| 48 |
+
uint32_t n_embd_head_v_mla = 0;
|
| 49 |
+
|
| 50 |
// for WavTokenizer
|
| 51 |
struct llama_hparams_posnet posnet;
|
| 52 |
struct llama_hparams_convnext convnext;
|
|
|
|
| 80 |
uint32_t time_decay_extra_dim = 0;
|
| 81 |
uint32_t wkv_head_size = 0;
|
| 82 |
uint32_t token_shift_count = 2;
|
| 83 |
+
uint32_t n_lora_decay = 0;
|
| 84 |
+
uint32_t n_lora_iclr = 0;
|
| 85 |
+
uint32_t n_lora_value_res_mix = 0;
|
| 86 |
+
uint32_t n_lora_gate = 0;
|
| 87 |
|
| 88 |
float rope_attn_factor = 1.0f;
|
| 89 |
float rope_freq_base_train;
|
| 90 |
+
float rope_freq_base_train_swa;
|
| 91 |
float rope_freq_scale_train;
|
| 92 |
+
float rope_freq_scale_train_swa;
|
| 93 |
uint32_t n_ctx_orig_yarn;
|
| 94 |
float rope_yarn_log_mul;
|
| 95 |
|
|
|
|
| 116 |
bool use_alibi = false;
|
| 117 |
bool attn_soft_cap = false;
|
| 118 |
|
| 119 |
+
uint32_t n_moe_layer_step = 0;
|
| 120 |
+
bool use_kq_norm = true;
|
| 121 |
+
uint32_t n_attn_chunk = 0;
|
| 122 |
+
// values below seem to be fixed on llama4
|
| 123 |
+
uint32_t n_no_rope_layer_step = 4;
|
| 124 |
+
uint32_t n_attn_temp_floor_scale = 8192;
|
| 125 |
+
float f_attn_temp_scale = 0.1;
|
| 126 |
+
|
| 127 |
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
|
| 128 |
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
|
| 129 |
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
|
|
|
|
| 152 |
|
| 153 |
// dimension of the recurrent state embeddings
|
| 154 |
uint32_t n_embd_v_s() const;
|
| 155 |
+
|
| 156 |
+
bool is_swa(uint32_t il) const;
|
| 157 |
};
|
| 158 |
|
| 159 |
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
|
examples/talk-llama/llama-impl.h
CHANGED
|
@@ -6,13 +6,13 @@
|
|
| 6 |
#include <vector>
|
| 7 |
|
| 8 |
#ifdef __GNUC__
|
| 9 |
-
#
|
| 10 |
-
#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
|
|
|
|
|
|
|
|
|
| 11 |
#else
|
| 12 |
-
#define LLAMA_ATTRIBUTE_FORMAT(...)
|
| 13 |
-
#endif
|
| 14 |
-
#else
|
| 15 |
-
#define LLAMA_ATTRIBUTE_FORMAT(...)
|
| 16 |
#endif
|
| 17 |
|
| 18 |
//
|
|
|
|
| 6 |
#include <vector>
|
| 7 |
|
| 8 |
#ifdef __GNUC__
|
| 9 |
+
# if defined(__MINGW32__) && !defined(__clang__)
|
| 10 |
+
# define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
| 11 |
+
# else
|
| 12 |
+
# define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
| 13 |
+
# endif
|
| 14 |
#else
|
| 15 |
+
# define LLAMA_ATTRIBUTE_FORMAT(...)
|
|
|
|
|
|
|
|
|
|
| 16 |
#endif
|
| 17 |
|
| 18 |
//
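Aside (not part of the synced diff): the point of LLAMA_ATTRIBUTE_FORMAT is to let GCC/Clang type-check variadic logging calls; the MinGW branch only swaps printf for gnu_printf. A minimal standalone example of the same idiom (my_log is an invented function, not llama.cpp API):

    #include <cstdarg>
    #include <cstdio>

    #ifdef __GNUC__
    #    define MY_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
    #else
    #    define MY_ATTRIBUTE_FORMAT(...)
    #endif

    MY_ATTRIBUTE_FORMAT(1, 2)
    static void my_log(const char * fmt, ...) {
        va_list args;
        va_start(args, fmt);
        vfprintf(stderr, fmt, args);
        va_end(args);
    }

    int main() {
        my_log("%d tokens\n", 42);     // ok
        // my_log("%s tokens\n", 42);  // would trigger -Wformat thanks to the attribute
        return 0;
    }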
|
examples/talk-llama/llama-io.cpp
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
| 1 |
+
#include "llama-io.h"
|
| 2 |
+
|
| 3 |
+
void llama_io_write_i::write_string(const std::string & str) {
|
| 4 |
+
uint32_t str_size = str.size();
|
| 5 |
+
|
| 6 |
+
write(&str_size, sizeof(str_size));
|
| 7 |
+
write(str.data(), str_size);
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
void llama_io_read_i::read_string(std::string & str) {
|
| 11 |
+
uint32_t str_size;
|
| 12 |
+
read_to(&str_size, sizeof(str_size));
|
| 13 |
+
|
| 14 |
+
str.assign((const char *) read(str_size), str_size);
|
| 15 |
+
}
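Aside (not part of the synced diff): write_string()/read_string() frame a string as a 32-bit length followed by the raw bytes. A self-contained round-trip of that framing on an in-memory buffer (not derived from the real llama_io classes):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>
    #include <string>
    #include <vector>

    int main() {
        std::vector<uint8_t> buf;

        // write: u32 length, then the bytes
        const std::string src = "hello";
        const uint32_t n = (uint32_t) src.size();
        buf.insert(buf.end(), (const uint8_t *) &n, (const uint8_t *) &n + sizeof(n));
        buf.insert(buf.end(), src.begin(), src.end());

        // read: u32 length, then assign that many bytes
        size_t off = 0;
        uint32_t m = 0;
        memcpy(&m, buf.data() + off, sizeof(m)); off += sizeof(m);
        const std::string dst((const char *) buf.data() + off, m);

        printf("round-tripped: '%s' (%u bytes)\n", dst.c_str(), m);
        return 0;
    }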
|
examples/talk-llama/llama-io.h
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cstddef>
|
| 4 |
+
#include <cstdint>
|
| 5 |
+
#include <string>
|
| 6 |
+
|
| 7 |
+
struct ggml_tensor;
|
| 8 |
+
|
| 9 |
+
class llama_io_write_i {
|
| 10 |
+
public:
|
| 11 |
+
llama_io_write_i() = default;
|
| 12 |
+
virtual ~llama_io_write_i() = default;
|
| 13 |
+
|
| 14 |
+
virtual void write(const void * src, size_t size) = 0;
|
| 15 |
+
virtual void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) = 0;
|
| 16 |
+
|
| 17 |
+
// bytes written so far
|
| 18 |
+
virtual size_t n_bytes() = 0;
|
| 19 |
+
|
| 20 |
+
void write_string(const std::string & str);
|
| 21 |
+
};
|
| 22 |
+
|
| 23 |
+
class llama_io_read_i {
|
| 24 |
+
public:
|
| 25 |
+
llama_io_read_i() = default;
|
| 26 |
+
virtual ~llama_io_read_i() = default;
|
| 27 |
+
|
| 28 |
+
virtual const uint8_t * read(size_t size) = 0;
|
| 29 |
+
virtual void read_to(void * dst, size_t size) = 0;
|
| 30 |
+
|
| 31 |
+
// bytes read so far
|
| 32 |
+
virtual size_t n_bytes() = 0;
|
| 33 |
+
|
| 34 |
+
void read_string(std::string & str);
|
| 35 |
+
};
|
examples/talk-llama/llama-kv-cache.cpp
CHANGED
|
@@ -6,86 +6,90 @@
|
|
| 6 |
#include "llama-model.h"
|
| 7 |
|
| 8 |
#include <algorithm>
|
|
|
|
| 9 |
#include <limits>
|
| 10 |
#include <map>
|
|
|
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
|
| 15 |
-
// the FA kernels require padding to avoid extra runtime boundary checks
|
| 16 |
-
return cparams.flash_attn ? 256u : 32u;
|
| 17 |
}
|
| 18 |
|
| 19 |
-
bool
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
bool offload) {
|
| 27 |
-
const struct llama_hparams & hparams = model.hparams;
|
| 28 |
-
|
| 29 |
const int32_t n_layer = hparams.n_layer;
|
| 30 |
|
| 31 |
-
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
|
| 37 |
LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
|
| 38 |
-
__func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer,
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
|
| 50 |
// create a context for each buffer type
|
| 51 |
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
| 52 |
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
| 53 |
auto it = ctx_map.find(buft);
|
| 54 |
if (it == ctx_map.end()) {
|
| 55 |
-
|
| 56 |
/*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
|
| 57 |
/*.mem_buffer =*/ NULL,
|
| 58 |
/*.no_alloc =*/ true,
|
| 59 |
};
|
|
|
|
| 60 |
ggml_context * ctx = ggml_init(params);
|
| 61 |
if (!ctx) {
|
| 62 |
return nullptr;
|
| 63 |
}
|
|
|
|
| 64 |
ctx_map[buft] = ctx;
|
| 65 |
-
|
|
|
|
| 66 |
return ctx;
|
| 67 |
}
|
|
|
|
| 68 |
return it->second;
|
| 69 |
};
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
|
| 74 |
for (int i = 0; i < n_layer; i++) {
|
| 75 |
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
|
| 76 |
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
|
| 77 |
|
| 78 |
-
|
| 79 |
|
| 80 |
ggml_backend_buffer_type_t buft;
|
| 81 |
if (offload) {
|
| 82 |
auto * dev = model.dev_layer(i);
|
| 83 |
buft = ggml_backend_dev_buffer_type(dev);
|
|
|
|
|
|
|
| 84 |
} else {
|
| 85 |
buft = ggml_backend_cpu_buffer_type();
|
| 86 |
}
|
| 87 |
-
ggml_context * ctx = ctx_for_buft(buft);
|
| 88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
if (!ctx) {
|
| 90 |
LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
|
| 91 |
return false;
|
|
@@ -95,8 +99,8 @@ bool llama_kv_cache_init(
|
|
| 95 |
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
|
| 96 |
ggml_format_name(k, "cache_k_l%d", i);
|
| 97 |
ggml_format_name(v, "cache_v_l%d", i);
|
| 98 |
-
|
| 99 |
-
|
| 100 |
}
|
| 101 |
|
| 102 |
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
|
@@ -111,20 +115,403 @@ bool llama_kv_cache_init(
|
|
| 111 |
}
|
| 112 |
ggml_backend_buffer_clear(buf, 0);
|
| 113 |
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
}
|
| 116 |
|
| 117 |
return true;
|
| 118 |
}
|
| 119 |
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
const uint32_t n_tokens = ubatch.n_tokens;
|
| 124 |
const uint32_t n_seqs = ubatch.n_seqs;
|
| 125 |
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
| 126 |
|
| 127 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
// For recurrent state architectures (like Mamba or RWKV),
|
| 129 |
// each cache cell can store the state for a whole sequence.
|
| 130 |
// A slot should always be contiguous.
|
|
@@ -132,7 +519,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|
| 132 |
// can only process batches with an equal number of new tokens in each sequence
|
| 133 |
GGML_ASSERT(ubatch.equal_seqs);
|
| 134 |
|
| 135 |
-
int32_t min =
|
| 136 |
int32_t max = 0;
|
| 137 |
|
| 138 |
// everything should fit if all seq_ids are smaller than the max
|
|
@@ -141,16 +528,16 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|
| 141 |
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
| 142 |
const llama_seq_id seq_id = ubatch.seq_id[s][j];
|
| 143 |
|
| 144 |
-
if (seq_id < 0 || (uint32_t) seq_id >=
|
| 145 |
// too big seq_id
|
| 146 |
// TODO: would it be possible to resize the cache instead?
|
| 147 |
-
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id,
|
| 148 |
-
return
|
| 149 |
}
|
| 150 |
if (j > 0) {
|
| 151 |
-
llama_kv_cell & seq =
|
| 152 |
if (seq.tail >= 0) {
|
| 153 |
-
llama_kv_cell & cell =
|
| 154 |
// clear cells from seq_ids that become shared
|
| 155 |
// (should not normally happen, but let's handle it anyway)
|
| 156 |
cell.seq_id.erase(seq_id);
|
|
@@ -158,7 +545,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|
| 158 |
if (cell.seq_id.empty()) {
|
| 159 |
cell.pos = -1;
|
| 160 |
cell.src = -1;
|
| 161 |
-
|
| 162 |
}
|
| 163 |
}
|
| 164 |
}
|
|
@@ -168,9 +555,9 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|
| 168 |
#ifndef NDEBUG
|
| 169 |
{
|
| 170 |
std::vector<int32_t> tails_verif;
|
| 171 |
-
tails_verif.assign(
|
| 172 |
-
for (uint32_t i = 0; i <
|
| 173 |
-
llama_kv_cell & cell =
|
| 174 |
for (llama_seq_id seq_id : cell.seq_id) {
|
| 175 |
if (tails_verif[seq_id] != -1) {
|
| 176 |
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
|
|
@@ -178,20 +565,20 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|
| 178 |
tails_verif[seq_id] = i;
|
| 179 |
}
|
| 180 |
}
|
| 181 |
-
for (uint32_t i = 0; i <
|
| 182 |
-
if (tails_verif[i] !=
|
| 183 |
-
LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i,
|
| 184 |
}
|
| 185 |
}
|
| 186 |
}
|
| 187 |
#endif
|
| 188 |
|
| 189 |
// find next empty cell
|
| 190 |
-
uint32_t next_empty_cell =
|
| 191 |
|
| 192 |
-
for (uint32_t i = 0; i <
|
| 193 |
-
if (next_empty_cell >=
|
| 194 |
-
llama_kv_cell & cell =
|
| 195 |
if (cell.is_empty()) { break; }
|
| 196 |
next_empty_cell += 1;
|
| 197 |
}
|
|
@@ -199,20 +586,20 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|
| 199 |
// find usable cell range
|
| 200 |
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 201 |
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
| 202 |
-
llama_kv_cell & seq_meta =
|
| 203 |
bool has_cell = false;
|
| 204 |
if (seq_meta.tail >= 0) {
|
| 205 |
-
llama_kv_cell & cell =
|
| 206 |
GGML_ASSERT(cell.has_seq_id(seq_id));
|
| 207 |
// does this seq_id "own" the cell?
|
| 208 |
if (cell.seq_id.size() == 1) { has_cell = true; }
|
| 209 |
}
|
| 210 |
if (!has_cell) {
|
| 211 |
-
llama_kv_cell & empty_cell =
|
| 212 |
GGML_ASSERT(empty_cell.is_empty());
|
| 213 |
// copy old tail into the empty cell
|
| 214 |
if (seq_meta.tail >= 0) {
|
| 215 |
-
llama_kv_cell & orig_cell =
|
| 216 |
empty_cell.pos = orig_cell.pos;
|
| 217 |
empty_cell.src = orig_cell.src;
|
| 218 |
orig_cell.seq_id.erase(seq_id);
|
|
@@ -222,9 +609,9 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|
| 222 |
// find next empty cell
|
| 223 |
if (s + 1 < n_seqs) {
|
| 224 |
next_empty_cell += 1;
|
| 225 |
-
for (uint32_t i = 0; i <
|
| 226 |
-
if (next_empty_cell >=
|
| 227 |
-
llama_kv_cell & cell =
|
| 228 |
if (cell.is_empty()) { break; }
|
| 229 |
next_empty_cell += 1;
|
| 230 |
}
|
|
@@ -237,10 +624,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|
| 237 |
// gather and re-order
|
| 238 |
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 239 |
int32_t dst_id = s + min;
|
| 240 |
-
int32_t src_id =
|
| 241 |
if (dst_id != src_id) {
|
| 242 |
-
llama_kv_cell & dst_cell =
|
| 243 |
-
llama_kv_cell & src_cell =
|
| 244 |
|
| 245 |
std::swap(dst_cell.pos, src_cell.pos);
|
| 246 |
std::swap(dst_cell.src, src_cell.src);
|
|
@@ -248,10 +635,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|
| 248 |
|
| 249 |
// swap tails (assuming they NEVER overlap)
|
| 250 |
for (const llama_seq_id seq_id : src_cell.seq_id) {
|
| 251 |
-
|
| 252 |
}
|
| 253 |
for (const llama_seq_id seq_id : dst_cell.seq_id) {
|
| 254 |
-
|
| 255 |
}
|
| 256 |
}
|
| 257 |
}
|
|
@@ -260,7 +647,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|
| 260 |
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 261 |
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
|
| 262 |
int32_t cell_id = s + min;
|
| 263 |
-
llama_kv_cell & cell =
|
| 264 |
|
| 265 |
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
|
| 266 |
// What should happen when the pos backtracks or skips a value?
|
|
@@ -273,41 +660,42 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|
| 273 |
for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
|
| 274 |
const llama_seq_id seq_id = ubatch.seq_id[s][j];
|
| 275 |
cell.seq_id.insert(seq_id);
|
| 276 |
-
|
| 277 |
}
|
| 278 |
}
|
| 279 |
|
| 280 |
// allow getting the range of used cells, from head to head + n
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
[](const llama_kv_cell& cell){ return !cell.is_empty(); });
|
| 285 |
|
| 286 |
// sanity check
|
| 287 |
-
return
|
| 288 |
}
|
|
|
|
| 289 |
// otherwise, one cell per token.
|
| 290 |
|
| 291 |
-
if (n_tokens >
|
| 292 |
-
LLAMA_LOG_ERROR("%s: n_tokens
|
| 293 |
-
return
|
| 294 |
}
|
| 295 |
|
| 296 |
uint32_t n_tested = 0;
|
| 297 |
|
| 298 |
while (true) {
|
| 299 |
-
if (
|
| 300 |
-
n_tested +=
|
| 301 |
-
|
| 302 |
continue;
|
| 303 |
}
|
| 304 |
|
| 305 |
bool found = true;
|
| 306 |
for (uint32_t i = 0; i < n_tokens; i++) {
|
| 307 |
-
if (
|
| 308 |
found = false;
|
| 309 |
-
|
| 310 |
-
n_tested
|
| 311 |
break;
|
| 312 |
}
|
| 313 |
}
|
|
@@ -316,31 +704,38 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
@@ -350,289 +745,549 @@ uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
@@ -642,7 +1297,7 @@ struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache
@@ -653,18 +1308,25 @@ void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
@@ -673,7 +1335,7 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct
@@ -711,8 +1373,8 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct
[Removed in these hunks: the old free-function API of llama-kv-cache.cpp that operated on a plain
`struct llama_kv_cache &` - the tail of llama_kv_cache_find_slot, llama_kv_cache_cell_max,
llama_kv_cache_clear, llama_kv_cache_seq_rm / _cp / _keep / _add / _div, and the old
llama_kv_cache_view helpers. Their replacements are the llama_kv_cache_unified member
functions shown in the added ('+') lines below.]
| 6 |
#include "llama-model.h"
|
| 7 |
|
| 8 |
#include <algorithm>
|
| 9 |
+
#include <cassert>
|
| 10 |
#include <limits>
|
| 11 |
#include <map>
|
| 12 |
+
#include <stdexcept>
|
| 13 |
|
| 14 |
+
llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
}
|
| 16 |
|
| 17 |
+
bool llama_kv_cache_unified::init(
|
| 18 |
+
const llama_model & model,
|
| 19 |
+
const llama_cparams & cparams,
|
| 20 |
+
ggml_type type_k,
|
| 21 |
+
ggml_type type_v,
|
| 22 |
+
uint32_t kv_size,
|
| 23 |
+
bool offload) {
|
|
|
|
|
|
|
|
|
|
| 24 |
const int32_t n_layer = hparams.n_layer;
|
| 25 |
|
| 26 |
+
has_shift = false;
|
| 27 |
|
| 28 |
+
recurrent = llama_model_is_recurrent(&model);
|
| 29 |
+
v_trans = !recurrent && !cparams.flash_attn;
|
| 30 |
+
can_shift = !recurrent;
|
| 31 |
|
| 32 |
LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
|
| 33 |
+
__func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift);
|
| 34 |
|
| 35 |
+
head = 0;
|
| 36 |
+
size = kv_size;
|
| 37 |
+
used = 0;
|
| 38 |
|
| 39 |
+
this->type_k = type_k;
|
| 40 |
+
this->type_v = type_v;
|
| 41 |
|
| 42 |
+
cells.clear();
|
| 43 |
+
cells.resize(kv_size);
|
| 44 |
|
| 45 |
// create a context for each buffer type
|
| 46 |
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
| 47 |
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
| 48 |
auto it = ctx_map.find(buft);
|
| 49 |
if (it == ctx_map.end()) {
|
| 50 |
+
ggml_init_params params = {
|
| 51 |
/*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
|
| 52 |
/*.mem_buffer =*/ NULL,
|
| 53 |
/*.no_alloc =*/ true,
|
| 54 |
};
|
| 55 |
+
|
| 56 |
ggml_context * ctx = ggml_init(params);
|
| 57 |
if (!ctx) {
|
| 58 |
return nullptr;
|
| 59 |
}
|
| 60 |
+
|
| 61 |
ctx_map[buft] = ctx;
|
| 62 |
+
ctxs.emplace_back(ctx);
|
| 63 |
+
|
| 64 |
return ctx;
|
| 65 |
}
|
| 66 |
+
|
| 67 |
return it->second;
|
| 68 |
};
|
| 69 |
|
| 70 |
+
k_l.reserve(n_layer);
|
| 71 |
+
v_l.reserve(n_layer);
|
| 72 |
|
| 73 |
for (int i = 0; i < n_layer; i++) {
|
| 74 |
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
|
| 75 |
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
|
| 76 |
|
| 77 |
+
const char * dev_name = "CPU";
|
| 78 |
|
| 79 |
ggml_backend_buffer_type_t buft;
|
| 80 |
if (offload) {
|
| 81 |
auto * dev = model.dev_layer(i);
|
| 82 |
buft = ggml_backend_dev_buffer_type(dev);
|
| 83 |
+
|
| 84 |
+
dev_name = ggml_backend_dev_name(dev);
|
| 85 |
} else {
|
| 86 |
buft = ggml_backend_cpu_buffer_type();
|
| 87 |
}
|
|
|
|
| 88 |
|
| 89 |
+
LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__,
|
| 90 |
+
i, n_embd_k_gqa, n_embd_v_gqa, dev_name);
|
| 91 |
+
|
| 92 |
+
ggml_context * ctx = ctx_for_buft(buft);
|
| 93 |
if (!ctx) {
|
| 94 |
LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
|
| 95 |
return false;
|
|
|
|
| 99 |
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
|
| 100 |
ggml_format_name(k, "cache_k_l%d", i);
|
| 101 |
ggml_format_name(v, "cache_v_l%d", i);
|
| 102 |
+
k_l.push_back(k);
|
| 103 |
+
v_l.push_back(v);
|
| 104 |
}
|
| 105 |
|
| 106 |
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
|
|
|
| 115 |
}
|
| 116 |
ggml_backend_buffer_clear(buf, 0);
|
| 117 |
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
| 118 |
+
bufs.emplace_back(buf);
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
return true;
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
int32_t llama_kv_cache_unified::get_n_tokens() const {
|
| 125 |
+
int32_t result = 0;
|
| 126 |
+
|
| 127 |
+
for (uint32_t i = 0; i < size; i++) {
|
| 128 |
+
result += cells[i].seq_id.size();
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
return result;
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
int32_t llama_kv_cache_unified::get_used_cells() const {
|
| 135 |
+
return used;
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
size_t llama_kv_cache_unified::total_size() const {
|
| 139 |
+
size_t size = 0;
|
| 140 |
+
for (const auto & buf : bufs) {
|
| 141 |
+
size += ggml_backend_buffer_get_size(buf.get());
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
return size;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
llama_pos llama_kv_cache_unified::pos_max() const {
|
| 148 |
+
llama_pos pos_max = -1;
|
| 149 |
+
for (const auto & cell : cells) {
|
| 150 |
+
pos_max = std::max(pos_max, cell.pos);
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
return pos_max;
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
void llama_kv_cache_unified::clear() {
|
| 157 |
+
for (int32_t i = 0; i < (int32_t) size; ++i) {
|
| 158 |
+
cells[i].pos = -1;
|
| 159 |
+
cells[i].seq_id.clear();
|
| 160 |
+
cells[i].src = -1;
|
| 161 |
+
cells[i].tail = -1;
|
| 162 |
+
}
|
| 163 |
+
head = 0;
|
| 164 |
+
used = 0;
|
| 165 |
+
|
| 166 |
+
for (auto & buf : bufs) {
|
| 167 |
+
ggml_backend_buffer_clear(buf.get(), 0);
|
| 168 |
+
}
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
| 172 |
+
uint32_t new_head = size;
|
| 173 |
+
|
| 174 |
+
if (p0 < 0) {
|
| 175 |
+
p0 = 0;
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
if (p1 < 0) {
|
| 179 |
+
p1 = std::numeric_limits<llama_pos>::max();
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
// models like Mamba or RWKV can't have a state partially erased
|
| 183 |
+
if (recurrent) {
|
| 184 |
+
if (seq_id >= (int64_t) size) {
|
| 185 |
+
// could be fatal
|
| 186 |
+
return false;
|
| 187 |
+
}
|
| 188 |
+
if (0 <= seq_id) {
|
| 189 |
+
int32_t & tail_id = cells[seq_id].tail;
|
| 190 |
+
if (tail_id >= 0) {
|
| 191 |
+
const llama_kv_cell & cell = cells[tail_id];
|
| 192 |
+
// partial intersection is invalid
|
| 193 |
+
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
|
| 194 |
+
return false;
|
| 195 |
+
}
|
| 196 |
+
// invalidate tails which will be cleared
|
| 197 |
+
if (p0 <= cell.pos && cell.pos < p1) {
|
| 198 |
+
tail_id = -1;
|
| 199 |
+
}
|
| 200 |
+
}
|
| 201 |
+
} else {
|
| 202 |
+
// seq_id is negative, then the range should include everything or nothing
|
| 203 |
+
if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
|
| 204 |
+
return false;
|
| 205 |
+
}
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
return true;
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
for (uint32_t i = 0; i < size; ++i) {
|
| 212 |
+
if (cells[i].pos >= p0 && cells[i].pos < p1) {
|
| 213 |
+
if (seq_id < 0) {
|
| 214 |
+
cells[i].seq_id.clear();
|
| 215 |
+
} else if (cells[i].has_seq_id(seq_id)) {
|
| 216 |
+
cells[i].seq_id.erase(seq_id);
|
| 217 |
+
} else {
|
| 218 |
+
continue;
|
| 219 |
+
}
|
| 220 |
+
if (cells[i].is_empty()) {
|
| 221 |
+
// keep count of the number of used cells
|
| 222 |
+
if (cells[i].pos >= 0) {
|
| 223 |
+
used--;
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
cells[i].pos = -1;
|
| 227 |
+
cells[i].src = -1;
|
| 228 |
+
|
| 229 |
+
if (new_head == size) {
|
| 230 |
+
new_head = i;
|
| 231 |
+
}
|
| 232 |
+
}
|
| 233 |
+
}
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
// If we freed up a slot, set head to it so searching can start there.
|
| 237 |
+
if (new_head != size && new_head < head) {
|
| 238 |
+
head = new_head;
|
| 239 |
}
|
| 240 |
|
| 241 |
return true;
|
| 242 |
}
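Note: `seq_rm()` above keeps the per-cell `seq_id` sets and the `used` counter in sync. A hedged usage sketch follows; `kv` and the position values are illustrative and not part of this diff.
// Illustrative only: `kv` stands for an initialized llama_kv_cache_unified.
// Drop positions [10, 20) that belong to sequence 0; cells shared with other
// sequences keep their entries because only seq_id 0 is erased from them.
kv.seq_rm(/*seq_id =*/ 0, /*p0 =*/ 10, /*p1 =*/ 20);

// A negative seq_id removes the range from every sequence; with p0 = -1 and
// p1 = -1 the whole cache is wiped for all sequences.
kv.seq_rm(-1, -1, -1);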
|
| 243 |
|
| 244 |
+
void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
| 245 |
+
if (seq_id_src == seq_id_dst) {
|
| 246 |
+
return;
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
if (p0 < 0) {
|
| 250 |
+
p0 = 0;
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
if (p1 < 0) {
|
| 254 |
+
p1 = std::numeric_limits<llama_pos>::max();
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
if (recurrent) {
|
| 258 |
+
if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
|
| 259 |
+
llama_kv_cell & tail_src = cells[seq_id_src];
|
| 260 |
+
llama_kv_cell & tail_dst = cells[seq_id_dst];
|
| 261 |
+
if (tail_dst.tail >= 0) {
|
| 262 |
+
// clear destination seq_id if it wasn't empty
|
| 263 |
+
llama_kv_cell & cell_dst = cells[tail_dst.tail];
|
| 264 |
+
|
| 265 |
+
cell_dst.seq_id.erase(seq_id_dst);
|
| 266 |
+
tail_dst.tail = -1;
|
| 267 |
+
if (cell_dst.seq_id.empty()) {
|
| 268 |
+
cell_dst.pos = -1;
|
| 269 |
+
cell_dst.delta = -1;
|
| 270 |
+
cell_dst.src = -1;
|
| 271 |
+
used -= 1;
|
| 272 |
+
}
|
| 273 |
+
}
|
| 274 |
+
if (tail_src.tail >= 0) {
|
| 275 |
+
llama_kv_cell & cell_src = cells[tail_src.tail];
|
| 276 |
+
|
| 277 |
+
cell_src.seq_id.insert(seq_id_dst);
|
| 278 |
+
tail_dst.tail = tail_src.tail;
|
| 279 |
+
}
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
return;
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
// otherwise, this is the KV cache of a Transformer-like model
|
| 286 |
+
head = 0;
|
| 287 |
+
|
| 288 |
+
for (uint32_t i = 0; i < size; ++i) {
|
| 289 |
+
if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) {
|
| 290 |
+
cells[i].seq_id.insert(seq_id_dst);
|
| 291 |
+
}
|
| 292 |
+
}
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
|
| 296 |
+
uint32_t new_head = size;
|
| 297 |
+
|
| 298 |
+
for (uint32_t i = 0; i < size; ++i) {
|
| 299 |
+
if (recurrent && (llama_seq_id) i != seq_id) {
|
| 300 |
+
cells[i].tail = -1;
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
if (!cells[i].has_seq_id(seq_id)) {
|
| 304 |
+
if (cells[i].pos >= 0) {
|
| 305 |
+
used--;
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
cells[i].pos = -1;
|
| 309 |
+
cells[i].src = -1;
|
| 310 |
+
cells[i].seq_id.clear();
|
| 311 |
+
|
| 312 |
+
if (new_head == size) {
|
| 313 |
+
new_head = i;
|
| 314 |
+
}
|
| 315 |
+
} else {
|
| 316 |
+
cells[i].seq_id.clear();
|
| 317 |
+
cells[i].seq_id.insert(seq_id);
|
| 318 |
+
}
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
// If we freed up a slot, set head to it so searching can start there.
|
| 322 |
+
if (new_head != size && new_head < head) {
|
| 323 |
+
head = new_head;
|
| 324 |
+
}
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
|
| 328 |
+
if (delta == 0) {
|
| 329 |
+
return;
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
uint32_t new_head = size;
|
| 333 |
+
|
| 334 |
+
if (p0 < 0) {
|
| 335 |
+
p0 = 0;
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
if (p1 < 0) {
|
| 339 |
+
p1 = std::numeric_limits<llama_pos>::max();
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
// If there is no range then return early to avoid looping over the cache.
|
| 343 |
+
if (p0 == p1) {
|
| 344 |
+
return;
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
if (recurrent) {
|
| 348 |
+
// for Mamba-like or RWKV models, only the pos needs to be shifted
|
| 349 |
+
if (0 <= seq_id && seq_id < (int64_t) size) {
|
| 350 |
+
const int32_t tail_id = cells[seq_id].tail;
|
| 351 |
+
if (tail_id >= 0) {
|
| 352 |
+
llama_kv_cell & cell = cells[tail_id];
|
| 353 |
+
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
| 354 |
+
cell.pos += delta;
|
| 355 |
+
}
|
| 356 |
+
}
|
| 357 |
+
}
|
| 358 |
+
return;
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
for (uint32_t i = 0; i < size; ++i) {
|
| 362 |
+
if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
|
| 363 |
+
has_shift = true;
|
| 364 |
+
cells[i].pos += delta;
|
| 365 |
+
cells[i].delta += delta;
|
| 366 |
+
|
| 367 |
+
if (cells[i].pos < 0) {
|
| 368 |
+
if (!cells[i].is_empty()) {
|
| 369 |
+
used--;
|
| 370 |
+
}
|
| 371 |
+
cells[i].pos = -1;
|
| 372 |
+
cells[i].seq_id.clear();
|
| 373 |
+
if (new_head == size) {
|
| 374 |
+
new_head = i;
|
| 375 |
+
}
|
| 376 |
+
}
|
| 377 |
+
}
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
// If we freed up a slot, set head to it so searching can start there.
|
| 381 |
+
// Otherwise we just start the next search from the beginning.
|
| 382 |
+
head = new_head != size ? new_head : 0;
|
| 383 |
+
}
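For context, a hedged sketch of how a caller might drive `seq_add()` for a context shift; `kv` and `n_discard` are illustrative names, not part of the diff.
// Illustrative only: a typical "context shift" for sequence 0, assuming `kv`
// is an initialized llama_kv_cache_unified.
const llama_pos n_discard = 256;              // example value
kv.seq_rm (0, 0, n_discard);                  // drop the oldest positions
kv.seq_add(0, n_discard, -1, -n_discard);     // shift the rest down; sets has_shift and per-cell delta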
|
| 384 |
+
|
| 385 |
+
void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
| 386 |
+
if (d == 1) {
|
| 387 |
+
return;
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
if (p0 < 0) {
|
| 391 |
+
p0 = 0;
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
if (p1 < 0) {
|
| 395 |
+
p1 = std::numeric_limits<llama_pos>::max();
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
// If there is no range then return early to avoid looping over the cache.
|
| 399 |
+
if (p0 == p1) {
|
| 400 |
+
return;
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
if (recurrent) {
|
| 404 |
+
// for Mamba-like or RWKV models, only the pos needs to be changed
|
| 405 |
+
if (0 <= seq_id && seq_id < (int64_t) size) {
|
| 406 |
+
const int32_t tail_id = cells[seq_id].tail;
|
| 407 |
+
if (tail_id >= 0) {
|
| 408 |
+
llama_kv_cell & cell = cells[tail_id];
|
| 409 |
+
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
| 410 |
+
cell.pos /= d;
|
| 411 |
+
}
|
| 412 |
+
}
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
return;
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
for (uint32_t i = 0; i < size; ++i) {
|
| 419 |
+
if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
|
| 420 |
+
has_shift = true;
|
| 421 |
+
|
| 422 |
+
{
|
| 423 |
+
llama_pos p_old = cells[i].pos;
|
| 424 |
+
cells[i].pos /= d;
|
| 425 |
+
cells[i].delta += cells[i].pos - p_old;
|
| 426 |
+
}
|
| 427 |
+
}
|
| 428 |
+
}
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
| 432 |
+
llama_pos result = 0;
|
| 433 |
+
|
| 434 |
+
for (uint32_t i = 0; i < size; ++i) {
|
| 435 |
+
if (cells[i].has_seq_id(seq_id)) {
|
| 436 |
+
result = std::max(result, cells[i].pos);
|
| 437 |
+
}
|
| 438 |
+
}
|
| 439 |
+
|
| 440 |
+
return result;
|
| 441 |
+
}
|
| 442 |
+
|
| 443 |
+
void llama_kv_cache_unified::defrag() {
|
| 444 |
+
if (!recurrent) {
|
| 445 |
+
do_defrag = true;
|
| 446 |
+
}
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
void llama_kv_cache_unified::restore() {
|
| 450 |
+
if (pending.ranges.empty()) {
|
| 451 |
+
return;
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
// TODO: tmp - move to llama_kv_cache_recurrent
|
| 455 |
+
if (recurrent) {
|
| 456 |
+
seq_rm(-1, -1, -1);
|
| 457 |
+
return;
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
uint32_t new_head = size;
|
| 461 |
+
|
| 462 |
+
for (auto & range : pending.ranges) {
|
| 463 |
+
for (uint32_t i = range.c0; i < range.c1; ++i) {
|
| 464 |
+
cells[i].seq_id.clear();
|
| 465 |
+
|
| 466 |
+
// keep count of the number of used cells
|
| 467 |
+
if (cells[i].pos >= 0) {
|
| 468 |
+
used--;
|
| 469 |
+
}
|
| 470 |
+
|
| 471 |
+
cells[i].pos = -1;
|
| 472 |
+
cells[i].src = -1;
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
new_head = std::min(new_head, range.c0);
|
| 476 |
+
}
|
| 477 |
+
|
| 478 |
+
if (new_head != size && new_head < head) {
|
| 479 |
+
head = new_head;
|
| 480 |
+
}
|
| 481 |
+
}
|
| 482 |
+
|
| 483 |
+
void llama_kv_cache_unified::commit() {
|
| 484 |
+
// TODO: tmp - move to llama_kv_cache_recurrent
|
| 485 |
+
if (recurrent) {
|
| 486 |
+
return;
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
if (pending.ranges.empty()) {
|
| 490 |
+
LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
|
| 491 |
+
__func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
|
| 492 |
+
return;
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
pending.ranges.clear();
|
| 496 |
+
}
|
| 497 |
+
|
| 498 |
+
bool llama_kv_cache_unified::get_can_shift() const {
|
| 499 |
+
return can_shift;
|
| 500 |
+
}
|
| 501 |
+
|
| 502 |
+
bool llama_kv_cache_unified::find_slot(
|
| 503 |
+
const llama_ubatch & ubatch) {
|
| 504 |
const uint32_t n_tokens = ubatch.n_tokens;
|
| 505 |
const uint32_t n_seqs = ubatch.n_seqs;
|
| 506 |
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
| 507 |
|
| 508 |
+
// if we have enough unused cells before the current head ->
|
| 509 |
+
// better to start searching from the beginning of the cache, hoping to fill it
|
| 510 |
+
if (head > used + 2*ubatch.n_tokens) {
|
| 511 |
+
head = 0;
|
| 512 |
+
}
|
| 513 |
+
|
| 514 |
+
if (recurrent) {
|
| 515 |
// For recurrent state architectures (like Mamba or RWKV),
|
| 516 |
// each cache cell can store the state for a whole sequence.
|
| 517 |
// A slot should always be contiguous.
|
|
|
|
| 519 |
// can only process batches with an equal number of new tokens in each sequence
|
| 520 |
GGML_ASSERT(ubatch.equal_seqs);
|
| 521 |
|
| 522 |
+
int32_t min = size - 1;
|
| 523 |
int32_t max = 0;
|
| 524 |
|
| 525 |
// everything should fit if all seq_ids are smaller than the max
|
|
|
|
| 528 |
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
| 529 |
const llama_seq_id seq_id = ubatch.seq_id[s][j];
|
| 530 |
|
| 531 |
+
if (seq_id < 0 || (uint32_t) seq_id >= size) {
|
| 532 |
// too big seq_id
|
| 533 |
// TODO: would it be possible to resize the cache instead?
|
| 534 |
+
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
|
| 535 |
+
return false;
|
| 536 |
}
|
| 537 |
if (j > 0) {
|
| 538 |
+
llama_kv_cell & seq = cells[seq_id];
|
| 539 |
if (seq.tail >= 0) {
|
| 540 |
+
llama_kv_cell & cell = cells[seq.tail];
|
| 541 |
// clear cells from seq_ids that become shared
|
| 542 |
// (should not normally happen, but let's handle it anyway)
|
| 543 |
cell.seq_id.erase(seq_id);
|
|
|
|
| 545 |
if (cell.seq_id.empty()) {
|
| 546 |
cell.pos = -1;
|
| 547 |
cell.src = -1;
|
| 548 |
+
used -= 1;
|
| 549 |
}
|
| 550 |
}
|
| 551 |
}
|
|
|
|
| 555 |
#ifndef NDEBUG
|
| 556 |
{
|
| 557 |
std::vector<int32_t> tails_verif;
|
| 558 |
+
tails_verif.assign(size, -1);
|
| 559 |
+
for (uint32_t i = 0; i < size; ++i) {
|
| 560 |
+
llama_kv_cell & cell = cells[i];
|
| 561 |
for (llama_seq_id seq_id : cell.seq_id) {
|
| 562 |
if (tails_verif[seq_id] != -1) {
|
| 563 |
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
|
|
|
|
| 565 |
tails_verif[seq_id] = i;
|
| 566 |
}
|
| 567 |
}
|
| 568 |
+
for (uint32_t i = 0; i < size; ++i) {
|
| 569 |
+
if (tails_verif[i] != cells[i].tail) {
|
| 570 |
+
LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
|
| 571 |
}
|
| 572 |
}
|
| 573 |
}
|
| 574 |
#endif
|
| 575 |
|
| 576 |
// find next empty cell
|
| 577 |
+
uint32_t next_empty_cell = head;
|
| 578 |
|
| 579 |
+
for (uint32_t i = 0; i < size; ++i) {
|
| 580 |
+
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
| 581 |
+
llama_kv_cell & cell = cells[next_empty_cell];
|
| 582 |
if (cell.is_empty()) { break; }
|
| 583 |
next_empty_cell += 1;
|
| 584 |
}
|
|
|
|
| 586 |
// find usable cell range
|
| 587 |
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 588 |
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
| 589 |
+
llama_kv_cell & seq_meta = cells[seq_id];
|
| 590 |
bool has_cell = false;
|
| 591 |
if (seq_meta.tail >= 0) {
|
| 592 |
+
llama_kv_cell & cell = cells[seq_meta.tail];
|
| 593 |
GGML_ASSERT(cell.has_seq_id(seq_id));
|
| 594 |
// does this seq_id "own" the cell?
|
| 595 |
if (cell.seq_id.size() == 1) { has_cell = true; }
|
| 596 |
}
|
| 597 |
if (!has_cell) {
|
| 598 |
+
llama_kv_cell & empty_cell = cells[next_empty_cell];
|
| 599 |
GGML_ASSERT(empty_cell.is_empty());
|
| 600 |
// copy old tail into the empty cell
|
| 601 |
if (seq_meta.tail >= 0) {
|
| 602 |
+
llama_kv_cell & orig_cell = cells[seq_meta.tail];
|
| 603 |
empty_cell.pos = orig_cell.pos;
|
| 604 |
empty_cell.src = orig_cell.src;
|
| 605 |
orig_cell.seq_id.erase(seq_id);
|
|
|
|
| 609 |
// find next empty cell
|
| 610 |
if (s + 1 < n_seqs) {
|
| 611 |
next_empty_cell += 1;
|
| 612 |
+
for (uint32_t i = 0; i < size; ++i) {
|
| 613 |
+
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
| 614 |
+
llama_kv_cell & cell = cells[next_empty_cell];
|
| 615 |
if (cell.is_empty()) { break; }
|
| 616 |
next_empty_cell += 1;
|
| 617 |
}
|
|
|
|
| 624 |
// gather and re-order
|
| 625 |
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 626 |
int32_t dst_id = s + min;
|
| 627 |
+
int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
|
| 628 |
if (dst_id != src_id) {
|
| 629 |
+
llama_kv_cell & dst_cell = cells[dst_id];
|
| 630 |
+
llama_kv_cell & src_cell = cells[src_id];
|
| 631 |
|
| 632 |
std::swap(dst_cell.pos, src_cell.pos);
|
| 633 |
std::swap(dst_cell.src, src_cell.src);
|
|
|
|
| 635 |
|
| 636 |
// swap tails (assuming they NEVER overlap)
|
| 637 |
for (const llama_seq_id seq_id : src_cell.seq_id) {
|
| 638 |
+
cells[seq_id].tail = src_id;
|
| 639 |
}
|
| 640 |
for (const llama_seq_id seq_id : dst_cell.seq_id) {
|
| 641 |
+
cells[seq_id].tail = dst_id;
|
| 642 |
}
|
| 643 |
}
|
| 644 |
}
|
|
|
|
| 647 |
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 648 |
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
|
| 649 |
int32_t cell_id = s + min;
|
| 650 |
+
llama_kv_cell & cell = cells[cell_id];
|
| 651 |
|
| 652 |
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
|
| 653 |
// What should happen when the pos backtracks or skips a value?
|
|
|
|
| 660 |
for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
|
| 661 |
const llama_seq_id seq_id = ubatch.seq_id[s][j];
|
| 662 |
cell.seq_id.insert(seq_id);
|
| 663 |
+
cells[seq_id].tail = cell_id;
|
| 664 |
}
|
| 665 |
}
|
| 666 |
|
| 667 |
// allow getting the range of used cells, from head to head + n
|
| 668 |
+
head = min;
|
| 669 |
+
n = max - min + 1;
|
| 670 |
+
used = std::count_if(cells.begin(), cells.end(),
|
| 671 |
[](const llama_kv_cell& cell){ return !cell.is_empty(); });
|
| 672 |
|
| 673 |
// sanity check
|
| 674 |
+
return n >= n_seqs;
|
| 675 |
}
|
| 676 |
+
|
| 677 |
// otherwise, one cell per token.
|
| 678 |
|
| 679 |
+
if (n_tokens > size) {
|
| 680 |
+
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
|
| 681 |
+
return false;
|
| 682 |
}
|
| 683 |
|
| 684 |
uint32_t n_tested = 0;
|
| 685 |
|
| 686 |
while (true) {
|
| 687 |
+
if (head + n_tokens > size) {
|
| 688 |
+
n_tested += size - head;
|
| 689 |
+
head = 0;
|
| 690 |
continue;
|
| 691 |
}
|
| 692 |
|
| 693 |
bool found = true;
|
| 694 |
for (uint32_t i = 0; i < n_tokens; i++) {
|
| 695 |
+
if (cells[head + i].pos >= 0) {
|
| 696 |
found = false;
|
| 697 |
+
head += i + 1;
|
| 698 |
+
n_tested += i + 1;
|
| 699 |
break;
|
| 700 |
}
|
| 701 |
}
|
|
|
|
| 704 |
break;
|
| 705 |
}
|
| 706 |
|
| 707 |
+
if (n_tested >= size) {
|
| 708 |
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
| 709 |
+
return false;
|
| 710 |
}
|
| 711 |
}
|
| 712 |
|
| 713 |
for (uint32_t s = 0; s < n_seqs; s++) {
|
| 714 |
for (uint32_t i = 0; i < n_seq_tokens; ++i) {
|
| 715 |
uint32_t k = s*n_seq_tokens + i;
|
| 716 |
+
cells[head + k].pos = ubatch.pos[k];
|
| 717 |
|
| 718 |
for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
|
| 719 |
+
cells[head + k].seq_id.insert(ubatch.seq_id[s][j]);
|
| 720 |
}
|
| 721 |
}
|
| 722 |
}
|
| 723 |
|
| 724 |
+
used += n_tokens;
|
| 725 |
+
|
| 726 |
+
pending.ranges.push_back({head, head + n_tokens});
|
| 727 |
+
|
| 728 |
+
return true;
|
| 729 |
+
}
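The non-recurrent branch of `find_slot()` above is a ring-buffer search for `n_tokens` consecutive free cells. A simplified, self-contained model of that search is sketched here; it is illustrative only and reuses the `llama_kv_cell` type from this file.
// Sketch: returns the first index of a run of n_tokens cells whose pos < 0, or -1 if none exists.
static int64_t find_free_run(const std::vector<llama_kv_cell> & cells, uint32_t head, uint32_t n_tokens) {
    const uint32_t size = (uint32_t) cells.size();

    if (n_tokens > size) {
        return -1;
    }

    uint32_t n_tested = 0;

    while (true) {
        // wrap around when the run would cross the end of the buffer
        if (head + n_tokens > size) {
            n_tested += size - head;
            head = 0;
            continue;
        }

        bool found = true;
        for (uint32_t i = 0; i < n_tokens; i++) {
            if (cells[head + i].pos >= 0) {
                found     = false;
                head     += i + 1;
                n_tested += i + 1;
                break;
            }
        }

        if (found) {
            return head;
        }

        if (n_tested >= size) {
            return -1;
        }
    }
}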
|
| 730 |
|
| 731 |
+
uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
|
| 732 |
+
// the FA kernels require padding to avoid extra runtime boundary checks
|
| 733 |
+
return cparams.flash_attn ? 256u : 32u;
|
| 734 |
}
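A hedged sketch of how a caller might apply this padding when sizing the cache, using ggml's GGML_PAD macro; `kv` and `cparams` are assumed to exist in the caller and are not part of the diff.
const uint32_t padding = kv.get_padding(cparams);          // 256 with flash attention, 32 otherwise
const uint32_t kv_size = GGML_PAD(cparams.n_ctx, padding); // round the cache size up to the padding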
|
| 735 |
|
| 736 |
+
uint32_t llama_kv_cache_unified::cell_max() const {
|
| 737 |
+
for (uint32_t i = size; i > 0; --i) {
|
| 738 |
+
const llama_kv_cell & cell = cells[i - 1];
|
| 739 |
|
| 740 |
if (cell.pos >= 0 && !cell.is_empty()) {
|
| 741 |
return i;
|
|
|
|
| 745 |
return 0;
|
| 746 |
}
|
| 747 |
|
| 748 |
+
size_t llama_kv_cache_unified::size_k_bytes() const {
|
| 749 |
+
size_t size_k_bytes = 0;
|
| 750 |
+
|
| 751 |
+
for (const auto & k : k_l) {
|
| 752 |
+
size_k_bytes += ggml_nbytes(k);
|
|
|
|
| 753 |
}
|
|
|
|
|
|
|
| 754 |
|
| 755 |
+
return size_k_bytes;
|
| 756 |
+
}
|
| 757 |
+
|
| 758 |
+
size_t llama_kv_cache_unified::size_v_bytes() const {
|
| 759 |
+
size_t size_v_bytes = 0;
|
| 760 |
+
|
| 761 |
+
for (const auto & v : v_l) {
|
| 762 |
+
size_v_bytes += ggml_nbytes(v);
|
| 763 |
}
|
| 764 |
+
|
| 765 |
+
return size_v_bytes;
|
| 766 |
}
|
| 767 |
|
| 768 |
+
bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
| 769 |
+
const uint32_t n_layer = hparams.n_layer;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 770 |
|
| 771 |
+
const uint32_t n_kv = cell_max();
|
| 772 |
+
const uint32_t n_used = used;
|
| 773 |
|
| 774 |
+
assert(n_used <= n_kv);
|
| 775 |
+
|
| 776 |
+
//const int64_t t_start = ggml_time_us();
|
| 777 |
+
|
| 778 |
+
// number of cells moved
|
| 779 |
+
uint32_t n_moves = 0;
|
| 780 |
+
|
| 781 |
+
// each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
|
| 782 |
+
// - source view, destination view, copy operation
|
| 783 |
+
// - x2 for keys and values
|
| 784 |
+
//const uint32_t max_moves = max_nodes()/(6*n_layer);
|
| 785 |
+
// TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
|
| 786 |
+
const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
|
| 787 |
+
|
| 788 |
+
// determine which KV cells to move where
|
| 789 |
+
//
|
| 790 |
+
// cell i moves to ids[i]
|
| 791 |
+
//
|
| 792 |
+
// if ids[i] == i || ids[i] == n_kv, then cell i is not moved
|
| 793 |
+
//
|
| 794 |
+
auto & ids = defrag_info.ids;
|
| 795 |
+
|
| 796 |
+
ids.clear();
|
| 797 |
+
ids.resize(n_kv, n_kv);
|
| 798 |
+
|
| 799 |
+
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
| 800 |
+
const auto & cell0 = cells[i0];
|
| 801 |
+
|
| 802 |
+
if (!cell0.is_empty()) {
|
| 803 |
+
ids[i0] = i0;
|
| 804 |
+
|
| 805 |
+
continue;
|
| 806 |
}
|
| 807 |
+
|
| 808 |
+
// found a hole - fill it with data from the end of the cache
|
| 809 |
+
|
| 810 |
+
uint32_t nh = 1;
|
| 811 |
+
|
| 812 |
+
// determine the size of the hole
|
| 813 |
+
while (i0 + nh < n_used && cells[i0 + nh].is_empty()) {
|
| 814 |
+
nh++;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 815 |
}
|
|
|
|
| 816 |
|
| 817 |
+
uint32_t nf = 0;
|
| 818 |
+
uint32_t is = n_kv - 1;
|
| 819 |
+
|
| 820 |
+
// starting from the end, find nh non-empty cells
|
| 821 |
+
for (; is > i0; --is) {
|
| 822 |
+
const auto & cell1 = cells[is];
|
| 823 |
+
|
| 824 |
+
if (cell1.is_empty() || ids[is] != n_kv) {
|
| 825 |
continue;
|
| 826 |
}
|
|
|
|
|
|
|
|
|
|
| 827 |
|
| 828 |
+
// non-empty cell which is not yet moved
|
| 829 |
+
nf++;
|
| 830 |
+
|
| 831 |
+
if (nf == nh) {
|
| 832 |
+
break;
|
| 833 |
}
|
| 834 |
}
|
|
|
|
| 835 |
|
| 836 |
+
// this can only happen if `n_used` is not accurate, which would be a bug
|
| 837 |
+
GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
|
| 838 |
|
| 839 |
+
nf = 0;
|
|
|
|
| 840 |
|
| 841 |
+
uint32_t i1 = is;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 842 |
|
| 843 |
+
// are we moving a continuous block of memory?
|
| 844 |
+
bool cont = false;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 845 |
|
| 846 |
+
// should we stop searching for the next move?
|
| 847 |
+
bool stop = false;
|
| 848 |
+
|
| 849 |
+
// go back and move the nf cells to the hole
|
| 850 |
+
for (; i1 < n_kv; ++i1) {
|
| 851 |
+
auto & cell1 = cells[i1];
|
| 852 |
+
|
| 853 |
+
if (cell1.is_empty() || ids[i1] != n_kv) {
|
| 854 |
+
if (n_moves == max_moves) {
|
| 855 |
+
stop = true;
|
| 856 |
+
break;
|
| 857 |
}
|
| 858 |
+
|
| 859 |
+
cont = false;
|
| 860 |
+
continue;
|
| 861 |
}
|
|
|
|
|
|
|
| 862 |
|
| 863 |
+
// this cell goes to (i0 + nf)
|
| 864 |
+
ids[i1] = i0 + nf;
|
| 865 |
+
|
| 866 |
+
// move the cell meta data
|
| 867 |
+
cells[i0 + nf] = cell1;
|
| 868 |
+
|
| 869 |
+
// clear the old cell and move the head there
|
| 870 |
+
cell1 = llama_kv_cell();
|
| 871 |
+
head = n_used;
|
| 872 |
+
|
| 873 |
+
if (!cont) {
|
| 874 |
+
n_moves++;
|
| 875 |
+
cont = true;
|
| 876 |
+
}
|
| 877 |
+
|
| 878 |
+
nf++;
|
| 879 |
+
|
| 880 |
+
if (nf == nh) {
|
| 881 |
+
break;
|
| 882 |
}
|
| 883 |
}
|
| 884 |
|
| 885 |
+
if (stop || n_moves == max_moves) {
|
| 886 |
+
break;
|
| 887 |
+
}
|
| 888 |
+
|
| 889 |
+
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
|
| 890 |
+
|
| 891 |
+
i0 += nh - 1;
|
| 892 |
+
}
|
| 893 |
+
|
| 894 |
+
if (n_moves == 0) {
|
| 895 |
+
return false;
|
| 896 |
}
|
|
|
|
| 897 |
|
| 898 |
+
LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves);
|
| 899 |
|
| 900 |
+
LLAMA_LOG_DEBUG("expected gf nodes: %u\n", 6*n_moves*n_layer);
|
| 901 |
+
|
| 902 |
+
return true;
|
| 903 |
+
}
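`defrag_prepare()` above only builds a move plan in `defrag_info.ids`; a hedged sketch of how that plan is meant to be read (illustrative only, reusing the `ids` and `n_kv` values computed above):
// ids[i] == i      -> cell i stays where it is
// ids[i] == n_kv   -> cell i is empty / not moved
// otherwise        -> cell i must be copied to cell ids[i] (K and V rows alike)
for (uint32_t i = 0; i < n_kv; ++i) {
    const uint32_t id = ids[i];
    if (id == i || id == n_kv) {
        continue;
    }
    // copy the K/V rows of every layer from slot i to slot id
    // (in the real code this is done with ggml ops in the defrag graph)
}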
|
| 904 |
+
|
| 905 |
+
void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
| 906 |
+
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
|
| 907 |
+
uint32_t cell_count = 0;
|
| 908 |
+
|
| 909 |
+
// Count the number of cells with the specified seq_id
|
| 910 |
+
// Find all the ranges of cells with this seq id (or all, when -1)
|
| 911 |
+
uint32_t cell_range_begin = size;
|
| 912 |
+
for (uint32_t i = 0; i < size; ++i) {
|
| 913 |
+
const auto & cell = cells[i];
|
| 914 |
+
if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
|
| 915 |
+
++cell_count;
|
| 916 |
+
if (cell_range_begin == size) {
|
| 917 |
+
cell_range_begin = i;
|
| 918 |
+
}
|
| 919 |
+
} else {
|
| 920 |
+
if (cell_range_begin != size) {
|
| 921 |
+
cell_ranges.emplace_back(cell_range_begin, i);
|
| 922 |
+
cell_range_begin = size;
|
| 923 |
+
}
|
| 924 |
}
|
| 925 |
}
|
| 926 |
+
if (cell_range_begin != size) {
|
| 927 |
+
cell_ranges.emplace_back(cell_range_begin, size);
|
| 928 |
+
}
|
| 929 |
+
|
| 930 |
+
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
| 931 |
+
uint32_t cell_count_check = 0;
|
| 932 |
+
for (const auto & range : cell_ranges) {
|
| 933 |
+
cell_count_check += range.second - range.first;
|
| 934 |
+
}
|
| 935 |
+
GGML_ASSERT(cell_count == cell_count_check);
|
| 936 |
+
|
| 937 |
+
io.write(&cell_count, sizeof(cell_count));
|
| 938 |
+
|
| 939 |
+
state_write_meta(io, cell_ranges, seq_id);
|
| 940 |
+
state_write_data(io, cell_ranges);
|
| 941 |
}
|
| 942 |
|
| 943 |
+
void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
| 944 |
+
uint32_t cell_count;
|
| 945 |
+
io.read_to(&cell_count, sizeof(cell_count));
|
| 946 |
|
| 947 |
+
bool res = true;
|
| 948 |
+
res = res && state_read_meta(io, cell_count, seq_id);
|
| 949 |
+
res = res && state_read_data(io, cell_count);
|
| 950 |
+
|
| 951 |
+
if (!res) {
|
| 952 |
+
if (seq_id == -1) {
|
| 953 |
+
clear();
|
|
|
|
|
|
|
|
|
|
| 954 |
} else {
|
| 955 |
+
seq_rm(seq_id, -1, -1);
|
|
|
|
| 956 |
}
|
| 957 |
+
throw std::runtime_error("failed to restore kv cache");
|
| 958 |
}
|
|
|
|
|
|
|
|
|
|
| 959 |
}
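A hedged sketch of the expected call pattern for session persistence; `writer` and `reader` stand for any concrete llama_io_write_i / llama_io_read_i implementations and are not part of the diff.
kv.state_write(writer, /*seq_id =*/ 0);   // writes cell metadata, then K rows, then V rows/columns
// ... later, possibly in a new process ...
kv.state_read(reader, /*seq_id =*/ 0);    // re-inserts the cells via find_slot() + commit();
                                          // on failure it clears or seq_rm()'s the sequence and throws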
|
| 960 |
|
| 961 |
+
void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
|
| 962 |
+
for (const auto & range : cell_ranges) {
|
| 963 |
+
for (uint32_t i = range.first; i < range.second; ++i) {
|
| 964 |
+
const auto & cell = cells[i];
|
| 965 |
+
const llama_pos pos = cell.pos;
|
| 966 |
+
const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
|
|
|
|
| 967 |
|
| 968 |
+
io.write(&pos, sizeof(pos));
|
| 969 |
+
io.write(&n_seq_id, sizeof(n_seq_id));
|
|
|
|
|
|
|
| 970 |
|
| 971 |
+
if (n_seq_id) {
|
| 972 |
+
for (auto seq_id : cell.seq_id) {
|
| 973 |
+
io.write(&seq_id, sizeof(seq_id));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 974 |
}
|
| 975 |
}
|
| 976 |
}
|
|
|
|
| 977 |
}
|
| 978 |
+
}
|
| 979 |
|
| 980 |
+
void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
|
| 981 |
+
const uint32_t v_trans = this->v_trans ? 1 : 0;
|
| 982 |
+
const uint32_t n_layer = hparams.n_layer;
|
|
|
|
|
|
|
| 983 |
|
| 984 |
+
io.write(&v_trans, sizeof(v_trans));
|
| 985 |
+
io.write(&n_layer, sizeof(n_layer));
|
| 986 |
+
|
| 987 |
+
std::vector<uint8_t> tmp_buf;
|
| 988 |
+
|
| 989 |
+
// Iterate and write all the keys first, each row is a cell
|
| 990 |
+
// Get whole range at a time
|
| 991 |
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 992 |
+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
| 993 |
+
|
| 994 |
+
// Write key type
|
| 995 |
+
const int32_t k_type_i = (int32_t)k_l[il]->type;
|
| 996 |
+
io.write(&k_type_i, sizeof(k_type_i));
|
| 997 |
+
|
| 998 |
+
// Write row size of key
|
| 999 |
+
const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
|
| 1000 |
+
io.write(&k_size_row, sizeof(k_size_row));
|
| 1001 |
+
|
| 1002 |
+
// Read each range of cells of k_size length each into tmp_buf and write out
|
| 1003 |
+
for (const auto & range : cell_ranges) {
|
| 1004 |
+
const size_t range_size = range.second - range.first;
|
| 1005 |
+
const size_t buf_size = range_size * k_size_row;
|
| 1006 |
+
io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
|
| 1007 |
}
|
| 1008 |
}
|
| 1009 |
|
| 1010 |
+
if (!v_trans) {
|
| 1011 |
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 1012 |
+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
|
|
|
| 1013 |
|
| 1014 |
+
// Write value type
|
| 1015 |
+
const int32_t v_type_i = (int32_t)v_l[il]->type;
|
| 1016 |
+
io.write(&v_type_i, sizeof(v_type_i));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1017 |
|
| 1018 |
+
// Write row size of value
|
| 1019 |
+
const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
|
| 1020 |
+
io.write(&v_size_row, sizeof(v_size_row));
|
| 1021 |
+
|
| 1022 |
+
// Read each range of cells of v_size length each into tmp_buf and write out
|
| 1023 |
+
for (const auto & range : cell_ranges) {
|
| 1024 |
+
const size_t range_size = range.second - range.first;
|
| 1025 |
+
const size_t buf_size = range_size * v_size_row;
|
| 1026 |
+
io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
|
| 1027 |
+
}
|
| 1028 |
+
}
|
| 1029 |
+
} else {
|
| 1030 |
+
// When v is transposed, we also need the element size and get the element ranges from each row
|
| 1031 |
+
const uint32_t kv_size = size;
|
| 1032 |
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 1033 |
+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1034 |
+
|
| 1035 |
+
// Write value type
|
| 1036 |
+
const int32_t v_type_i = (int32_t)v_l[il]->type;
|
| 1037 |
+
io.write(&v_type_i, sizeof(v_type_i));
|
| 1038 |
+
|
| 1039 |
+
// Write element size
|
| 1040 |
+
const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
|
| 1041 |
+
io.write(&v_size_el, sizeof(v_size_el));
|
| 1042 |
+
|
| 1043 |
+
// Write GQA embedding size
|
| 1044 |
+
io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
|
| 1045 |
+
|
| 1046 |
+
// For each row, we get the element values of each cell
|
| 1047 |
+
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
| 1048 |
+
// Read each range of cells of v_size_el length each into tmp_buf and write out
|
| 1049 |
+
for (const auto & range : cell_ranges) {
|
| 1050 |
+
const size_t range_size = range.second - range.first;
|
| 1051 |
+
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
|
| 1052 |
+
const size_t buf_size = range_size * v_size_el;
|
| 1053 |
+
io.write_tensor(v_l[il], src_offset, buf_size);
|
| 1054 |
}
|
| 1055 |
}
|
| 1056 |
}
|
|
|
|
| 1057 |
}
|
| 1058 |
+
}
|
| 1059 |
|
| 1060 |
+
bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
|
| 1061 |
+
if (dest_seq_id != -1) {
|
| 1062 |
+
// single sequence
|
| 1063 |
|
| 1064 |
+
seq_rm(dest_seq_id, -1, -1);
|
| 1065 |
+
|
| 1066 |
+
llama_sbatch sbatch;
|
| 1067 |
+
llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
| 1068 |
+
|
| 1069 |
+
batch.n_tokens = cell_count;
|
| 1070 |
+
batch.n_seq_tokens = cell_count;
|
| 1071 |
+
batch.n_seqs = 1;
|
| 1072 |
+
|
| 1073 |
+
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 1074 |
+
llama_pos pos;
|
| 1075 |
+
uint32_t n_seq_id;
|
| 1076 |
+
|
| 1077 |
+
io.read_to(&pos, sizeof(pos));
|
| 1078 |
+
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
| 1079 |
+
|
| 1080 |
+
if (n_seq_id != 0) {
|
| 1081 |
+
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
| 1082 |
+
return false;
|
| 1083 |
}
|
| 1084 |
+
|
| 1085 |
+
batch.pos[i] = pos;
|
| 1086 |
+
}
|
| 1087 |
+
batch.n_seq_id[0] = 1;
|
| 1088 |
+
batch.seq_id[0] = &dest_seq_id;
|
| 1089 |
+
if (!find_slot(batch)) {
|
| 1090 |
+
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
| 1091 |
+
return false;
|
| 1092 |
+
}
|
| 1093 |
+
commit();
|
| 1094 |
+
|
| 1095 |
+
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
| 1096 |
+
// Assume that this is one contiguous block of cells
|
| 1097 |
+
GGML_ASSERT(head + cell_count <= size);
|
| 1098 |
+
GGML_ASSERT(cells[head].pos == batch.pos[0]);
|
| 1099 |
+
GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
|
| 1100 |
+
GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
|
| 1101 |
+
GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
|
| 1102 |
+
} else {
|
| 1103 |
+
// whole KV cache restore
|
| 1104 |
+
|
| 1105 |
+
if (cell_count > size) {
|
| 1106 |
+
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
|
| 1107 |
+
return false;
|
| 1108 |
}
|
|
|
|
|
|
|
| 1109 |
|
| 1110 |
+
clear();
|
| 1111 |
+
|
| 1112 |
+
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 1113 |
+
llama_kv_cell & cell = cells[i];
|
| 1114 |
+
|
| 1115 |
+
llama_pos pos;
|
| 1116 |
+
uint32_t n_seq_id;
|
| 1117 |
+
|
| 1118 |
+
io.read_to(&pos, sizeof(pos));
|
| 1119 |
+
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
| 1120 |
+
|
| 1121 |
+
cell.pos = pos;
|
| 1122 |
|
| 1123 |
+
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
| 1124 |
+
llama_seq_id seq_id;
|
| 1125 |
+
io.read_to(&seq_id, sizeof(seq_id));
|
| 1126 |
+
|
| 1127 |
+
// TODO: llama_kv_cache_unified should have a notion of max sequences
|
| 1128 |
+
//if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
|
| 1129 |
+
if (seq_id < 0) {
|
| 1130 |
+
//LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
|
| 1131 |
+
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
|
| 1132 |
+
return false;
|
| 1133 |
+
}
|
| 1134 |
+
|
| 1135 |
+
cell.seq_id.insert(seq_id);
|
| 1136 |
+
|
| 1137 |
+
if (recurrent) {
|
| 1138 |
+
int32_t & tail = cells[seq_id].tail;
|
| 1139 |
+
if (tail != -1) {
|
| 1140 |
+
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
|
| 1141 |
+
return false;
|
| 1142 |
+
}
|
| 1143 |
+
tail = i;
|
| 1144 |
+
}
|
| 1145 |
+
}
|
| 1146 |
}
|
| 1147 |
+
|
| 1148 |
+
head = 0;
|
| 1149 |
+
used = cell_count;
|
| 1150 |
}
|
| 1151 |
|
| 1152 |
+
if (recurrent) {
|
| 1153 |
+
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 1154 |
+
uint32_t cell_id = head + i;
|
| 1155 |
+
// make sure the recurrent states will keep their restored state
|
| 1156 |
+
cells[cell_id].src = cell_id;
|
| 1157 |
+
}
|
| 1158 |
+
}
|
| 1159 |
+
|
| 1160 |
+
return true;
|
| 1161 |
}
|
| 1162 |
|
| 1163 |
+
bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
|
| 1164 |
+
uint32_t v_trans;
|
| 1165 |
+
uint32_t n_layer;
|
| 1166 |
+
io.read_to(&v_trans, sizeof(v_trans));
|
| 1167 |
+
io.read_to(&n_layer, sizeof(n_layer));
|
| 1168 |
+
|
| 1169 |
+
if (n_layer != hparams.n_layer) {
|
| 1170 |
+
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
|
| 1171 |
+
return false;
|
| 1172 |
+
}
|
| 1173 |
+
if (cell_count > size) {
|
| 1174 |
+
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
|
| 1175 |
+
return false;
|
| 1176 |
+
}
|
| 1177 |
+
if (this->v_trans != (bool) v_trans) { // compare the cache's own layout with the serialized flag
|
| 1178 |
+
LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
|
| 1179 |
+
return false;
|
| 1180 |
}
|
|
|
|
| 1181 |
|
| 1182 |
+
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
|
| 1183 |
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 1184 |
+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
| 1185 |
|
| 1186 |
+
// Read type of key
|
| 1187 |
+
int32_t k_type_i_ref;
|
| 1188 |
+
io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
|
| 1189 |
+
const int32_t k_type_i = (int32_t) k_l[il]->type;
|
| 1190 |
+
if (k_type_i != k_type_i_ref) {
|
| 1191 |
+
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
|
| 1192 |
+
return false;
|
| 1193 |
+
}
|
| 1194 |
+
|
| 1195 |
+
// Read row size of key
|
| 1196 |
+
uint64_t k_size_row_ref;
|
| 1197 |
+
io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
|
| 1198 |
+
const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
|
| 1199 |
+
if (k_size_row != k_size_row_ref) {
|
| 1200 |
+
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
|
| 1201 |
+
return false;
|
| 1202 |
+
}
|
| 1203 |
+
|
| 1204 |
+
if (cell_count) {
|
| 1205 |
+
// Read and set the keys for the whole cell range
|
| 1206 |
+
ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
|
| 1207 |
+
}
|
| 1208 |
}
|
| 1209 |
|
| 1210 |
+
if (!v_trans) {
|
| 1211 |
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 1212 |
+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1213 |
|
| 1214 |
+
// Read type of value
|
| 1215 |
+
int32_t v_type_i_ref;
|
| 1216 |
+
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
| 1217 |
+
const int32_t v_type_i = (int32_t)v_l[il]->type;
|
| 1218 |
+
if (v_type_i != v_type_i_ref) {
|
| 1219 |
+
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
| 1220 |
+
return false;
|
| 1221 |
+
}
|
| 1222 |
+
|
| 1223 |
+
// Read row size of value
|
| 1224 |
+
uint64_t v_size_row_ref;
|
| 1225 |
+
io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
|
| 1226 |
+
const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
|
| 1227 |
+
if (v_size_row != v_size_row_ref) {
|
| 1228 |
+
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
|
| 1229 |
+
return false;
|
| 1230 |
+
}
|
| 1231 |
|
| 1232 |
+
if (cell_count) {
|
| 1233 |
+
// Read and set the values for the whole cell range
|
| 1234 |
+
ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
|
| 1235 |
+
}
|
| 1236 |
+
}
|
| 1237 |
+
} else {
|
| 1238 |
+
// For each layer, read the values for each cell (transposed)
|
| 1239 |
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 1240 |
+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1241 |
+
|
| 1242 |
+
// Read type of value
|
| 1243 |
+
int32_t v_type_i_ref;
|
| 1244 |
+
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
| 1245 |
+
const int32_t v_type_i = (int32_t)v_l[il]->type;
|
| 1246 |
+
if (v_type_i != v_type_i_ref) {
|
| 1247 |
+
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
| 1248 |
+
return false;
|
| 1249 |
+
}
|
| 1250 |
+
|
| 1251 |
+
// Read element size of value
|
| 1252 |
+
uint32_t v_size_el_ref;
|
| 1253 |
+
io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
|
| 1254 |
+
const size_t v_size_el = ggml_type_size(v_l[il]->type);
|
| 1255 |
+
if (v_size_el != v_size_el_ref) {
|
| 1256 |
+
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
|
| 1257 |
+
return false;
|
| 1258 |
+
}
|
| 1259 |
+
|
| 1260 |
+
// Read GQA embedding size
|
| 1261 |
+
uint32_t n_embd_v_gqa_ref;
|
| 1262 |
+
io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
|
| 1263 |
+
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
|
| 1264 |
+
LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
|
| 1265 |
+
return false;
|
| 1266 |
+
}
|
| 1267 |
+
|
| 1268 |
+
if (cell_count) {
|
| 1269 |
+
// For each row in the transposed matrix, read the values for the whole cell range
|
| 1270 |
+
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
| 1271 |
+
const size_t dst_offset = (head + j * size) * v_size_el;
|
| 1272 |
+
ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
| 1273 |
+
}
|
| 1274 |
+
}
|
| 1275 |
+
}
|
| 1276 |
+
}
|
| 1277 |
+
|
| 1278 |
+
return true;
|
| 1279 |
}
|
| 1280 |
|
| 1281 |
//
|
| 1282 |
// kv cache view
|
| 1283 |
//
|
| 1284 |
|
| 1285 |
+
llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max) {
|
| 1286 |
+
llama_kv_cache_view result = {
|
| 1287 |
/*.n_cells = */ 0,
|
| 1288 |
/*.n_seq_max = */ n_seq_max,
|
| 1289 |
/*.token_count = */ 0,
|
| 1290 |
+
/*.used_cells = */ kv.get_used_cells(),
|
| 1291 |
/*.max_contiguous = */ 0,
|
| 1292 |
/*.max_contiguous_idx = */ -1,
|
| 1293 |
/*.cells = */ nullptr,
|
|
|
|
| 1297 |
return result;
|
| 1298 |
}
|
| 1299 |
|
| 1300 |
+
void llama_kv_cache_view_free(llama_kv_cache_view * view) {
|
| 1301 |
if (view->cells != nullptr) {
|
| 1302 |
free(view->cells);
|
| 1303 |
view->cells = nullptr;
|
|
|
|
| 1308 |
}
|
| 1309 |
}
|
| 1310 |
|
| 1311 |
+
void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv) {
|
| 1312 |
+
// TODO: rework this in the future, for now quick hack
|
| 1313 |
+
const llama_kv_cache_unified * kvu = dynamic_cast<const llama_kv_cache_unified *>(kv);
|
| 1314 |
+
if (kvu == nullptr) {
|
| 1315 |
+
LLAMA_LOG_ERROR("%s: the kv_cache_view currently works only with llama_kv_cache_unified\n", __func__);
|
| 1316 |
+
return;
|
| 1317 |
+
}
|
| 1318 |
+
|
| 1319 |
+
if (uint32_t(view->n_cells) < kvu->size || view->cells == nullptr) {
|
| 1320 |
+
view->n_cells = int32_t(kvu->size);
|
| 1321 |
+
void * p = realloc(view->cells, sizeof(llama_kv_cache_view_cell) * view->n_cells);
|
| 1322 |
GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
|
| 1323 |
+
view->cells = (llama_kv_cache_view_cell *)p;
|
| 1324 |
p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
|
| 1325 |
GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
|
| 1326 |
view->cells_sequences = (llama_seq_id *)p;
|
| 1327 |
}
|
| 1328 |
|
| 1329 |
+
const std::vector<llama_kv_cell> & kv_cells = kvu->cells;
|
| 1330 |
llama_kv_cache_view_cell * c_curr = view->cells;
|
| 1331 |
llama_seq_id * cs_curr = view->cells_sequences;
|
| 1332 |
int32_t used_cells = 0;
|
|
|
|
| 1335 |
uint32_t max_contig = 0;
|
| 1336 |
int32_t max_contig_idx = -1;
|
| 1337 |
|
| 1338 |
+
for (int32_t i = 0; i < int32_t(kvu->size); i++, c_curr++, cs_curr += view->n_seq_max) {
|
| 1339 |
const size_t curr_size = kv_cells[i].seq_id.size();
|
| 1340 |
token_count += curr_size;
|
| 1341 |
c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
|
|
|
|
| 1373 |
view->max_contiguous_idx = max_contig_idx;
|
| 1374 |
view->token_count = token_count;
|
| 1375 |
view->used_cells = used_cells;
|
| 1376 |
+
if (uint32_t(used_cells) != kvu->used) {
|
| 1377 |
LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
|
| 1378 |
+
__func__, kvu->used, used_cells);
|
| 1379 |
}
|
| 1380 |
}
|
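Before the header changes, a hedged sketch of how a decode path is expected to drive the new pending-range API that replaces the old llama_kv_slot_restorer; `kv` and `ubatch` are illustrative names.
if (!kv.find_slot(ubatch)) {
    kv.restore();   // roll back the ranges recorded by earlier find_slot() calls in this batch
    return 1;       // could not place the tokens
}
// ... build and compute the graph using the head/n set by find_slot() ...
kv.commit();        // success: drop the pending ranges so a later restore() is a no-op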
examples/talk-llama/llama-kv-cache.h
CHANGED
|
@@ -1,15 +1,51 @@
@@ -29,190 +65,149 @@ struct llama_kv_cell {
[Removed in these hunks: the old header-side API of llama-kv-cache.h - the ring-buffer
`struct llama_kv_cache` definition, the free-function declarations (llama_kv_cache_get_padding,
llama_kv_cache_find_slot, llama_kv_cache_cell_max and the seq_* helpers), and the
`struct llama_kv_slot_restorer` helper that saved head/n and rolled back failed
llama_kv_cache_find_slot calls via llama_kv_cache_seq_rm. The replacement declarations
(llama_kv_cache as a llama_memory_i with restore()/commit()) appear in the added ('+') lines below.]
| 1 |
#pragma once
|
| 2 |
|
| 3 |
#include "llama.h"
|
| 4 |
+
#include "llama-io.h"
|
| 5 |
+
#include "llama-memory.h"
|
| 6 |
|
| 7 |
#include "ggml-cpp.h"
|
| 8 |
|
| 9 |
+
#include <functional>
|
| 10 |
#include <set>
|
| 11 |
#include <vector>
|
| 12 |
|
| 13 |
+
struct llama_cparams;
|
| 14 |
+
struct llama_hparams;
|
| 15 |
+
struct llama_ubatch;
|
| 16 |
+
|
| 17 |
+
struct llama_kv_cache : public llama_memory_i {
|
| 18 |
+
using llama_memory_i::llama_memory_i;
|
| 19 |
+
|
| 20 |
+
virtual void restore() = 0; // call if batch processing fails - restores the cache state
|
| 21 |
+
virtual void commit() = 0; // call after successful batch processing - clears any pending state
|
| 22 |
+
|
| 23 |
+
virtual int32_t get_n_tokens() const = 0;
|
| 24 |
+
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
|
| 25 |
+
|
| 26 |
+
virtual bool get_can_shift() const = 0;
|
| 27 |
+
|
| 28 |
+
bool get_can_edit() const override { return get_can_shift(); }
|
| 29 |
+
};
|
| 30 |
+
|
| 31 |
+
struct llama_kv_cache_guard {
|
| 32 |
+
llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
|
| 33 |
+
|
| 34 |
+
~llama_kv_cache_guard() {
|
| 35 |
+
kv->restore();
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
void commit() {
|
| 39 |
+
kv->commit();
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
private:
|
| 43 |
+
llama_kv_cache * kv;
|
| 44 |
+
};
|
| 45 |
+
|
| 46 |
struct llama_kv_cell {
|
| 47 |
llama_pos pos = -1;
|
| 48 |
+
llama_pos delta = 0;
|
| 49 |
int32_t src = -1; // used by recurrent state models to copy states
|
| 50 |
int32_t tail = -1;
|
| 51 |
|
|
|
|
| 65 |
};
|
| 66 |
|
| 67 |
// ring-buffer of cached KV data
|
| 68 |
+
// TODO: pimpl
|
| 69 |
+
// TODO: add notion of max sequences
|
| 70 |
+
class llama_kv_cache_unified : public llama_kv_cache {
|
| 71 |
+
public:
|
| 72 |
+
// can be used to query data from the model if needed
|
| 73 |
+
struct callbacks {
|
| 74 |
+
std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
|
| 75 |
+
};
|
| 76 |
+
|
| 77 |
+
llama_kv_cache_unified(
|
| 78 |
+
const llama_hparams & hparams,
|
| 79 |
+
callbacks cbs);
|
| 80 |
+
|
| 81 |
+
virtual ~llama_kv_cache_unified() = default;
|
| 82 |
+
|
| 83 |
+
// TODO: become constructor
|
| 84 |
+
bool init(
|
| 85 |
+
const llama_model & model, // TODO: do not reference the model
|
| 86 |
+
const llama_cparams & cparams,
|
| 87 |
+
ggml_type type_k,
|
| 88 |
+
ggml_type type_v,
|
| 89 |
+
uint32_t kv_size,
|
| 90 |
+
bool offload);
|
| 91 |
|
| 92 |
+
int32_t get_n_tokens() const override;
|
| 93 |
+
int32_t get_used_cells() const override;
|
| 94 |
|
| 95 |
+
size_t total_size() const;
|
|
|
|
| 96 |
|
| 97 |
+
// TODO: better data structures to reduce the cost of this operation
|
| 98 |
+
llama_pos pos_max() const;
|
| 99 |
|
| 100 |
+
void clear() override;
|
| 101 |
+
void defrag() override;
|
| 102 |
|
| 103 |
+
virtual void restore() override;
|
| 104 |
+
virtual void commit() override;
|
| 105 |
|
| 106 |
+
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
| 107 |
+
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
| 108 |
+
void seq_keep(llama_seq_id seq_id) override;
|
| 109 |
+
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
| 110 |
+
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
| 111 |
|
| 112 |
+
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
|
|
|
| 113 |
|
| 114 |
+
bool get_can_shift() const override;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
+
// find an empty slot of size "n_tokens" in the cache
|
| 117 |
+
// updates the cache head
|
| 118 |
+
// Note: On success, it's important that cache.head points
|
| 119 |
+
// to the first cell of the slot.
|
| 120 |
+
bool find_slot(const llama_ubatch & batch);
|
| 121 |
|
| 122 |
+
// TODO: maybe not needed
|
| 123 |
+
uint32_t get_padding(const llama_cparams & cparams) const;
|
|
|
|
|
|
|
| 124 |
|
| 125 |
+
// find how many cells are currently in use
|
| 126 |
+
uint32_t cell_max() const;
|
| 127 |
|
| 128 |
+
size_t size_k_bytes() const;
|
| 129 |
+
size_t size_v_bytes() const;
|
| 130 |
|
| 131 |
+
// defrag
|
|
|
|
| 132 |
|
| 133 |
+
struct {
|
| 134 |
+
std::vector<uint32_t> ids;
|
| 135 |
+
} defrag_info;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
+
// return true if cells have been moved
|
| 138 |
+
bool defrag_prepare(int32_t n_max_nodes);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
+
// commit/restore cache
|
|
|
|
| 141 |
|
| 142 |
+
struct slot_range {
|
| 143 |
+
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
|
| 144 |
+
uint32_t c1 = 0;
|
| 145 |
+
};
|
| 146 |
|
| 147 |
+
// pending cell updates that are not yet committed
|
| 148 |
+
struct {
|
| 149 |
+
std::vector<slot_range> ranges;
|
| 150 |
+
} pending;
|
|
|
|
| 151 |
|
| 152 |
+
// state write/load
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
+
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
|
| 155 |
+
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
|
|
|
|
| 156 |
|
| 157 |
+
// members
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
+
const llama_hparams & hparams;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
+
callbacks cbs;
|
|
|
|
|
|
|
| 162 |
|
| 163 |
+
bool has_shift = false;
|
| 164 |
+
bool do_defrag = false;
|
| 165 |
|
| 166 |
+
// TODO: remove this and implement llama_kv_cache_recurrent instead
|
| 167 |
+
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
|
| 168 |
|
| 169 |
+
bool v_trans = true; // the value tensor is transposed
|
| 170 |
+
bool can_shift = false;
|
| 171 |
|
| 172 |
+
// Note: The value of head isn't only used to optimize searching
|
| 173 |
+
// for a free KV slot. llama_decode_impl also uses it, so it
|
| 174 |
+
// cannot be freely changed after a slot has been allocated.
|
| 175 |
+
uint32_t head = 0;
|
| 176 |
+
uint32_t size = 0;
|
| 177 |
+
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
| 178 |
|
| 179 |
+
// computed before each graph build
|
| 180 |
+
uint32_t n = 0;
|
|
|
|
| 181 |
|
| 182 |
+
std::vector<llama_kv_cell> cells;
|
| 183 |
|
| 184 |
+
std::vector<ggml_tensor *> k_l; // per layer
|
| 185 |
+
std::vector<ggml_tensor *> v_l;
|
| 186 |
|
| 187 |
+
private:
|
| 188 |
+
ggml_type type_k = GGML_TYPE_F16;
|
| 189 |
+
ggml_type type_v = GGML_TYPE_F16;
|
| 190 |
|
| 191 |
+
std::vector<ggml_context_ptr> ctxs;
|
| 192 |
+
std::vector<ggml_backend_buffer_ptr> bufs;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
+
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
| 195 |
+
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
|
|
|
| 196 |
|
| 197 |
+
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
| 198 |
+
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
| 199 |
+
};
|
| 200 |
|
| 201 |
+
// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
|
| 202 |
+
//class llama_kv_cache_recurrent : public llama_kv_cache_unified {
|
| 203 |
+
//public:
|
| 204 |
+
// using llama_kv_cache_unified::llama_kv_cache_unified;
|
| 205 |
+
//};
|
| 206 |
|
| 207 |
+
//
|
| 208 |
+
// kv cache view
|
| 209 |
+
//
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
+
llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
+
void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv);
|
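The restore()/commit() pair and llama_kv_cache_guard declared above are meant to make batch processing transactional: pending slot ranges are rolled back unless explicitly committed. A hedged usage sketch; decode_step() and its body are illustrative, not code from this commit:

static bool decode_step(llama_kv_cache_unified & kv, const llama_ubatch & ubatch) {
    llama_kv_cache_guard guard(&kv); // destructor calls kv.restore()

    const bool ok = kv.find_slot(ubatch); // the real code also builds and runs the graph here
    if (!ok) {
        return false; // guard rolls back the pending slot ranges
    }

    guard.commit(); // clears the pending ranges, so the restore() in the destructor has nothing to undo
    return true;
}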
examples/talk-llama/llama-memory.cpp
ADDED

@@ -0,0 +1 @@
+ #include "llama-memory.h"
examples/talk-llama/llama-memory.h
ADDED

@@ -0,0 +1,21 @@
+ #pragma once
+
+ #include "llama.h"
+
+ // general concept of LLM memory
+ // the KV cache is a type of LLM memory, but there can be other types
+ class llama_memory_i {
+ public:
+     virtual void clear()  = 0;
+     virtual void defrag() = 0;
+
+     virtual bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
+     virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
+     virtual void seq_keep(llama_seq_id seq_id) = 0;
+     virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
+     virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
+
+     virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
+
+     virtual bool get_can_edit() const = 0;
+ };
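For illustration only, a minimal (and deliberately useless) implementation of the interface above; it compiles against llama-memory.h and shows which members a concrete memory type such as llama_kv_cache_unified has to provide. The class name is hypothetical:

class llama_null_memory : public llama_memory_i {
public:
    void clear()  override {}
    void defrag() override {}

    bool seq_rm  (llama_seq_id, llama_pos, llama_pos) override { return true; }
    void seq_cp  (llama_seq_id, llama_seq_id, llama_pos, llama_pos) override {}
    void seq_keep(llama_seq_id) override {}
    void seq_add (llama_seq_id, llama_pos, llama_pos, llama_pos) override {}
    void seq_div (llama_seq_id, llama_pos, llama_pos, int) override {}

    llama_pos seq_pos_max(llama_seq_id) const override { return -1; } // nothing stored

    bool get_can_edit() const override { return false; }
};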
examples/talk-llama/llama-mmap.cpp
CHANGED

@@ -8,6 +8,7 @@
  #include <climits>
  #include <stdexcept>
  #include <cerrno>
+ #include <algorithm>

  #ifdef __has_include
  #if __has_include(<unistd.h>)

@@ -34,6 +35,10 @@
  #include <io.h>
  #endif

+ #if defined(__APPLE__)
+ #include <TargetConditionals.h>
+ #endif
+
  // TODO: consider moving to llama-impl.h if needed in more places
  #if defined(_WIN32)
  static std::string llama_format_win_err(DWORD err) {

@@ -471,7 +476,11 @@ struct llama_mlock::impl {
      char* errmsg = std::strerror(errno);
      bool suggest = (errno == ENOMEM);
+ #if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX)
+     // visionOS/tvOS dont't support RLIMIT_MEMLOCK
+     // Skip resource limit checks on visionOS/tvOS
+     suggest = false;
+ #else
      struct rlimit lock_limit;
      if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) {
          suggest = false;

@@ -479,6 +488,7 @@ struct llama_mlock::impl {
      if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) {
          suggest = false;
      }
+ #endif

      LLAMA_LOG_WARN("warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
              size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
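The mlock change above only affects the "suggest raising RLIMIT_MEMLOCK" hint. The same heuristic, pulled out into a standalone sketch; the function name is illustrative, and it assumes errno was just set by a failed mlock and that TargetConditionals.h is included on Apple platforms:

#include <cerrno>
#include <sys/resource.h>

static bool should_suggest_raising_memlock_limit(size_t want_bytes) {
    bool suggest = (errno == ENOMEM);
#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX)
    suggest = false; // these targets do not support RLIMIT_MEMLOCK
#else
    struct rlimit lock_limit;
    if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) {
        suggest = false; // could not query the limit at all
    }
    if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + want_bytes)) {
        suggest = false; // hard limit already exceeds current + requested size
    }
#endif
    return suggest;
}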
examples/talk-llama/llama-mmap.h
CHANGED

@@ -1,5 +1,6 @@
  #pragma once

+ #include <cstdint>
  #include <memory>
  #include <vector>
examples/talk-llama/llama-model-loader.cpp
CHANGED

@@ -445,7 +445,8 @@ llama_model_loader::llama_model_loader(
          std::vector<std::string> & splits,
          bool use_mmap,
          bool check_tensors,
+         const llama_model_kv_override * param_overrides_p,
+         const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
      int trace = 0;
      if (getenv("LLAMA_TRACE")) {
          trace = atoi(getenv("LLAMA_TRACE"));

@@ -457,6 +458,8 @@ llama_model_loader::llama_model_loader(
          }
      }

+     tensor_buft_overrides = param_tensor_buft_overrides_p;
+
      // Load the main GGUF
      struct ggml_context * ctx = NULL;
      struct gguf_init_params params = {

@@ -600,7 +603,9 @@ llama_model_loader::llama_model_loader(
      if (trace > 0) {
          const uint16_t sid = w.idx;
-         LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ]\n", __func__,
+         LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ] %8.2f MiB\n", __func__,
+                 sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str(),
+                 ggml_nbytes(tensor)/1024.0f/1024.0f);
      }
  }

@@ -640,9 +645,9 @@ llama_model_loader::llama_model_loader(
      ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED);

      {
+         uint32_t ftype_val = 0;
+         if (get_key(LLM_KV_GENERAL_FILE_TYPE, ftype_val, false)) {
+             ftype = (llama_ftype) ftype_val;
          }
      }
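The new param_tensor_buft_overrides_p argument is stored as-is and consulted later when tensors are assigned to backend buffers. A hedged sketch of how a caller could build such an override list; the null-pattern terminator and the choice of the CPU buffer type are assumptions made for illustration, not requirements stated in this diff:

#include <vector>
#include "llama.h"
#include "ggml-backend.h"

// keep expert FFN tensors in host memory, let the loader decide everything else
static std::vector<llama_model_tensor_buft_override> make_overrides() {
    std::vector<llama_model_tensor_buft_override> ov;
    ov.push_back({ "ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type() });
    ov.push_back({ nullptr, nullptr }); // assumed terminator entry
    return ov;
}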
examples/talk-llama/llama-model-loader.h
CHANGED

@@ -77,8 +77,9 @@ struct llama_model_loader {

      llama_mmaps mappings;

-     std::map<std::string,
-     std::unordered_map<std::string,
+     std::map<std::string, llama_tensor_weight, weight_name_comparer> weights_map;
+     std::unordered_map<std::string, llama_model_kv_override> kv_overrides;
+     const llama_model_tensor_buft_override * tensor_buft_overrides;

      gguf_context_ptr meta;
      std::vector<ggml_context_ptr> contexts;

@@ -95,7 +96,8 @@ struct llama_model_loader {
          std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
          bool use_mmap,
          bool check_tensors,
+         const llama_model_kv_override * param_overrides_p,
+         const llama_model_tensor_buft_override * param_tensor_buft_overrides_p);

      template<typename T>
      typename std::enable_if<std::is_integral<T>::value, bool>::type
examples/talk-llama/llama-model.cpp
CHANGED

The diff for this file is too large to render. See raw diff.
examples/talk-llama/llama-model.h
CHANGED

@@ -2,7 +2,9 @@

  #include "llama.h"
  #include "llama-arch.h"
+ #include "llama-graph.h"
  #include "llama-hparams.h"
+ #include "llama-memory.h"
  #include "llama-vocab.h"

  #include <memory>

@@ -10,6 +12,8 @@
  #include <unordered_map>
  #include <vector>

+ struct llama_cparams;
+ struct llama_ubatch;
  struct llama_model_loader;

  // available models

@@ -25,6 +29,7 @@ enum llm_type {
      LLM_TYPE_109M,
      LLM_TYPE_137M,
      LLM_TYPE_160M,
+     LLM_TYPE_190M,
      LLM_TYPE_220M,
      LLM_TYPE_250M,
      LLM_TYPE_270M,

@@ -39,8 +44,10 @@ enum llm_type {
      LLM_TYPE_1_4B,
      LLM_TYPE_1_5B,
      LLM_TYPE_1_6B,
+     LLM_TYPE_1_8B,
      LLM_TYPE_2B,
      LLM_TYPE_2_8B,
+     LLM_TYPE_2_9B,
      LLM_TYPE_3B,
      LLM_TYPE_4B,
      LLM_TYPE_6B,

@@ -78,6 +85,9 @@ enum llm_type {
      LLM_TYPE_10B_128x3_66B,
      LLM_TYPE_57B_A14B,
      LLM_TYPE_27B,
+     LLM_TYPE_290B,
+     LLM_TYPE_17B_16E,  // llama4 Scout
+     LLM_TYPE_17B_128E, // llama4 Maverick
  };

  struct llama_layer_posnet {

@@ -161,6 +171,8 @@ struct llama_layer {
      struct ggml_tensor * wq_b = nullptr;
      struct ggml_tensor * wkv_a_mqa = nullptr;
      struct ggml_tensor * wkv_b = nullptr;
+     struct ggml_tensor * wk_b = nullptr;
+     struct ggml_tensor * wv_b = nullptr;
      struct ggml_tensor * wq_cross = nullptr;
      struct ggml_tensor * wk_cross = nullptr;
      struct ggml_tensor * wv_cross = nullptr;

@@ -256,6 +268,20 @@ struct llama_layer {
      struct ggml_tensor * time_mix_receptance_b = nullptr;
      struct ggml_tensor * time_mix_gate = nullptr;

+     // rwkv7
+     struct ggml_tensor * time_mix_w0 = nullptr;
+     struct ggml_tensor * time_mix_a0 = nullptr;
+     struct ggml_tensor * time_mix_a1 = nullptr;
+     struct ggml_tensor * time_mix_a2 = nullptr;
+     struct ggml_tensor * time_mix_v0 = nullptr;
+     struct ggml_tensor * time_mix_v1 = nullptr;
+     struct ggml_tensor * time_mix_v2 = nullptr;
+     struct ggml_tensor * time_mix_g1 = nullptr;
+     struct ggml_tensor * time_mix_g2 = nullptr;
+     struct ggml_tensor * time_mix_k_k = nullptr;
+     struct ggml_tensor * time_mix_k_a = nullptr;
+     struct ggml_tensor * time_mix_r_k = nullptr;
+
      struct ggml_tensor * time_mix_ln = nullptr;
      struct ggml_tensor * time_mix_ln_b = nullptr;
      struct ggml_tensor * time_mix_output = nullptr;

@@ -347,7 +373,7 @@ struct llama_model {
      std::string desc() const;

      size_t size() const;
-     size_t
+     size_t n_tensors() const;
      size_t n_devices() const;

      // total number of parameters in the model

@@ -360,11 +386,26 @@ struct llama_model {

      ggml_backend_buffer_type_t select_buft(int il) const;

+     bool has_tensor_overrides() const;
+
      const struct ggml_tensor * get_tensor(const char * name) const;

+     // TODO: move this to new llm_arch_model_i interface
+     llama_memory_i * create_memory() const; // TODO: params
+
+     // TODO: move this to new llm_arch_model_i interface
+     llm_graph_result_ptr build_graph(
+             const llm_graph_params & params,
+                        ggml_cgraph * gf,
+                     llm_graph_type   type) const;
+
  private:
      struct impl;
      std::unique_ptr<impl> pimpl;
  };

  const char * llm_type_name(llm_type type);
+
+ // For internal test use
+ // TODO: remove
+ const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model);
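Most of the additions above are declarations consumed by llama-context.cpp and the new graph builder. As a small self-contained example, the existing get_tensor() accessor can be used like this; the tensor name follows the usual GGUF convention and is only an example:

static void log_embedding_size(const llama_model & model) {
    const struct ggml_tensor * tok = model.get_tensor("token_embd.weight");
    if (tok != nullptr) {
        LLAMA_LOG_INFO("%s: token_embd.weight has %lld elements\n", __func__, (long long) ggml_nelements(tok));
    }
}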
examples/talk-llama/llama-quant.cpp
CHANGED

@@ -10,6 +10,7 @@
  #include <cinttypes>
  #include <fstream>
  #include <mutex>
+ #include <regex>
  #include <thread>
  #include <unordered_map>

@@ -47,8 +48,14 @@ struct quantize_state_impl {
      {}
  };

+ // changes to this struct must be replicated in quantize.cpp
+ struct tensor_quantization {
+     std::string name;
+     ggml_type quant = GGML_TYPE_COUNT;
+ };
+
  static void llama_tensor_dequantize_impl(
+     ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
      const size_t nelements, const int nthread
  ) {
      if (output.size() < nelements) {

@@ -527,7 +534,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
  }

      std::vector<std::string> splits = {};
-     llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides);
+     llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides, nullptr);
      ml.init_mappings(false); // no prefetching

      llama_model model(llama_model_default_params());

@@ -536,7 +543,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
      model.load_hparams(ml);
      model.load_stats (ml);

+     quantize_state_impl qs(model, params);

      if (params->only_copy) {
          ftype = ml.ftype;

@@ -661,7 +668,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
      // populate the original tensors so we get an initial meta data
      for (const auto * it : tensors) {
          uint16_t i_split = params->keep_split ? it->idx : 0;
+         ggml_tensor * tensor = it->tensor;
          if (!ctx_outs[i_split]) {
              ctx_outs[i_split].reset(gguf_init_empty());
          }

@@ -710,7 +717,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
      new_ofstream(0);
      for (const auto * it : tensors) {
          const auto & weight = *it;
+         ggml_tensor * tensor = weight.tensor;
          if (weight.idx != cur_split && params->keep_split) {
              close_ofstream();
              new_ofstream(weight.idx);

@@ -756,10 +763,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
          // NOTE: can't use LLM_TN here because the layer number is not known
          quantize &= name.find("ssm_conv1d.weight") == std::string::npos;

-         // do not quantize RWKV's
+         // do not quantize RWKV's small yet 2D weights
          quantize &= name.find("time_mix_first.weight") == std::string::npos;
+         quantize &= name.find("time_mix_w0.weight") == std::string::npos;
          quantize &= name.find("time_mix_w1.weight") == std::string::npos;
          quantize &= name.find("time_mix_w2.weight") == std::string::npos;
+         quantize &= name.find("time_mix_v0.weight") == std::string::npos;
+         quantize &= name.find("time_mix_v1.weight") == std::string::npos;
+         quantize &= name.find("time_mix_v2.weight") == std::string::npos;
+         quantize &= name.find("time_mix_a0.weight") == std::string::npos;
+         quantize &= name.find("time_mix_a1.weight") == std::string::npos;
+         quantize &= name.find("time_mix_a2.weight") == std::string::npos;
+         quantize &= name.find("time_mix_g1.weight") == std::string::npos;
+         quantize &= name.find("time_mix_g2.weight") == std::string::npos;
          quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
          quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
          quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;

@@ -767,7 +783,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
          // do not quantize relative position bias (T5)
          quantize &= name.find("attn_rel_b.weight") == std::string::npos;

+         ggml_type new_type;
          void * new_data;
          size_t new_size;

@@ -777,6 +793,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
          // get more optimal quantization type based on the tensor shape, layer, etc.
          if (!params->pure && ggml_is_quantized(default_type)) {
              new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
+             // unless the user specifies a type
+             if (params->tensor_types) {
+                 const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
+                 for (const auto & [tname, qtype] : tensor_types) {
+                     if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) {
+                         if (qtype != new_type) {
+                             LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype));
+                         }
+                         new_type = qtype;
+                         break;
+                     }
+                 }
+             }
          }
          if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
              new_type = params->token_embedding_type;

@@ -901,8 +930,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
  // interface implementation
  //

+ llama_model_quantize_params llama_model_quantize_default_params() {
+     llama_model_quantize_params result = {
          /*.nthread                     =*/ 0,
          /*.ftype                       =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
          /*.output_tensor_type          =*/ GGML_TYPE_COUNT,

@@ -914,6 +943,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
          /*.keep_split                  =*/ false,
          /*.imatrix                     =*/ nullptr,
          /*.kv_overrides                =*/ nullptr,
+         /*.tensor_type                 =*/ nullptr,
      };

      return result;
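The tensor_types override added above matches tensor names against user-supplied regular expressions and takes the first hit. The same lookup, reduced to a standalone, runnable sketch; tensor_quantization_demo and the integer type codes are stand-ins for the llama.cpp types, and main() is only a demo harness:

#include <cstdio>
#include <regex>
#include <string>
#include <vector>

struct tensor_quantization_demo {
    std::string name;   // regex pattern, e.g. "attn_v\\.weight"
    int         quant;  // stands in for ggml_type in this sketch
};

// first override whose pattern matches the tensor name wins, otherwise keep the default
static int resolve_type(const std::string & tensor_name, int default_type,
                        const std::vector<tensor_quantization_demo> & overrides) {
    for (const auto & tq : overrides) {
        if (std::regex_search(tensor_name, std::regex(tq.name))) {
            return tq.quant;
        }
    }
    return default_type;
}

int main() {
    const std::vector<tensor_quantization_demo> overrides = {
        { "attn_v\\.weight", 8 },
    };
    printf("%d\n", resolve_type("blk.0.attn_v.weight", 2, overrides)); // prints 8
    return 0;
}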
examples/talk-llama/llama-sampling.cpp
CHANGED
|
@@ -316,6 +316,13 @@ static uint32_t get_rng_seed(uint32_t seed) {
|
|
| 316 |
|
| 317 |
// llama_sampler API
|
| 318 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
const char * llama_sampler_name(const struct llama_sampler * smpl) {
|
| 320 |
if (!smpl->iface) {
|
| 321 |
return "(null)";
|
|
@@ -347,10 +354,10 @@ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
|
|
| 347 |
}
|
| 348 |
|
| 349 |
if (smpl->ctx == nullptr) {
|
| 350 |
-
return
|
| 351 |
/* .iface = */ smpl->iface,
|
| 352 |
-
/* .ctx = */ nullptr
|
| 353 |
-
|
| 354 |
}
|
| 355 |
|
| 356 |
GGML_ABORT("the sampler does not support cloning");
|
|
@@ -472,15 +479,15 @@ static struct llama_sampler_i llama_sampler_chain_i = {
|
|
| 472 |
};
|
| 473 |
|
| 474 |
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
|
| 475 |
-
return
|
| 476 |
/* .iface = */ &llama_sampler_chain_i,
|
| 477 |
/* .ctx = */ new llama_sampler_chain {
|
| 478 |
/* .params = */ params,
|
| 479 |
/* .samplers = */ {},
|
| 480 |
/* .t_sample_us = */ 0,
|
| 481 |
/* .n_sample = */ 0,
|
| 482 |
-
}
|
| 483 |
-
|
| 484 |
}
|
| 485 |
|
| 486 |
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
|
|
@@ -546,10 +553,10 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
|
|
| 546 |
};
|
| 547 |
|
| 548 |
struct llama_sampler * llama_sampler_init_greedy() {
|
| 549 |
-
return
|
| 550 |
/* .iface = */ &llama_sampler_greedy_i,
|
| 551 |
-
/* .ctx = */ nullptr
|
| 552 |
-
|
| 553 |
}
|
| 554 |
|
| 555 |
// dist
|
|
@@ -608,14 +615,14 @@ static struct llama_sampler_i llama_sampler_dist_i = {
|
|
| 608 |
|
| 609 |
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
| 610 |
auto seed_cur = get_rng_seed(seed);
|
| 611 |
-
return
|
| 612 |
/* .iface = */ &llama_sampler_dist_i,
|
| 613 |
/* .ctx = */ new llama_sampler_dist {
|
| 614 |
/* .seed = */ seed,
|
| 615 |
/* .seed_cur = */ seed_cur,
|
| 616 |
/* .rng = */ std::mt19937(seed_cur),
|
| 617 |
-
}
|
| 618 |
-
|
| 619 |
}
|
| 620 |
|
| 621 |
// softmax
|
|
@@ -638,10 +645,10 @@ static struct llama_sampler_i llama_sampler_softmax_i = {
|
|
| 638 |
};
|
| 639 |
|
| 640 |
struct llama_sampler * llama_sampler_init_softmax() {
|
| 641 |
-
return
|
| 642 |
/* .iface = */ &llama_sampler_softmax_i,
|
| 643 |
-
/* .ctx = */ nullptr
|
| 644 |
-
|
| 645 |
}
|
| 646 |
|
| 647 |
// top-k
|
|
@@ -678,12 +685,12 @@ static struct llama_sampler_i llama_sampler_top_k_i = {
|
|
| 678 |
};
|
| 679 |
|
| 680 |
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
| 681 |
-
return
|
| 682 |
/* .iface = */ &llama_sampler_top_k_i,
|
| 683 |
/* .ctx = */ new llama_sampler_top_k {
|
| 684 |
/* .k = */ k,
|
| 685 |
-
}
|
| 686 |
-
|
| 687 |
}
|
| 688 |
|
| 689 |
// top-p
|
|
@@ -744,13 +751,13 @@ static struct llama_sampler_i llama_sampler_top_p_i = {
|
|
| 744 |
};
|
| 745 |
|
| 746 |
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
|
| 747 |
-
return
|
| 748 |
/* .iface = */ &llama_sampler_top_p_i,
|
| 749 |
/* .ctx = */ new llama_sampler_top_p {
|
| 750 |
/* .p = */ p,
|
| 751 |
/* .min_keep = */ min_keep,
|
| 752 |
-
}
|
| 753 |
-
|
| 754 |
}
|
| 755 |
|
| 756 |
// min-p
|
|
@@ -840,13 +847,13 @@ static struct llama_sampler_i llama_sampler_min_p_i = {
|
|
| 840 |
};
|
| 841 |
|
| 842 |
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
|
| 843 |
-
return
|
| 844 |
/* .iface = */ &llama_sampler_min_p_i,
|
| 845 |
/* .ctx = */ new llama_sampler_min_p {
|
| 846 |
/* .p = */ p,
|
| 847 |
/* .min_keep = */ min_keep,
|
| 848 |
-
}
|
| 849 |
-
|
| 850 |
}
|
| 851 |
|
| 852 |
// typical
|
|
@@ -939,13 +946,13 @@ static struct llama_sampler_i llama_sampler_typical_i = {
|
|
| 939 |
};
|
| 940 |
|
| 941 |
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
|
| 942 |
-
return
|
| 943 |
/* .iface = */ &llama_sampler_typical_i,
|
| 944 |
/* .ctx = */ new llama_sampler_typical {
|
| 945 |
/* .p = */ p,
|
| 946 |
/* .min_keep = */ min_keep,
|
| 947 |
-
}
|
| 948 |
-
|
| 949 |
}
|
| 950 |
|
| 951 |
// temp
|
|
@@ -983,12 +990,12 @@ static struct llama_sampler_i llama_sampler_temp_i = {
|
|
| 983 |
};
|
| 984 |
|
| 985 |
struct llama_sampler * llama_sampler_init_temp(float temp) {
|
| 986 |
-
return
|
| 987 |
/* .iface = */ &llama_sampler_temp_i,
|
| 988 |
/* .ctx = */ new llama_sampler_temp {
|
| 989 |
/*.temp = */ temp,
|
| 990 |
-
}
|
| 991 |
-
|
| 992 |
}
|
| 993 |
|
| 994 |
// temp-ext
|
|
@@ -1093,14 +1100,14 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
|
|
| 1093 |
};
|
| 1094 |
|
| 1095 |
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
|
| 1096 |
-
return
|
| 1097 |
/* .iface = */ &llama_sampler_temp_ext_i,
|
| 1098 |
/* .ctx = */ new llama_sampler_temp_ext {
|
| 1099 |
/* .temp = */ temp,
|
| 1100 |
/* .delta = */ delta,
|
| 1101 |
/* .exponent = */ exponent,
|
| 1102 |
-
}
|
| 1103 |
-
|
| 1104 |
}
|
| 1105 |
|
| 1106 |
// xtc
|
|
@@ -1185,7 +1192,7 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
|
|
| 1185 |
|
| 1186 |
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
|
| 1187 |
auto seed_cur = get_rng_seed(seed);
|
| 1188 |
-
return
|
| 1189 |
/* .iface = */ &llama_sampler_xtc_i,
|
| 1190 |
/* .ctx = */ new llama_sampler_xtc {
|
| 1191 |
/* .probability = */ p,
|
|
@@ -1194,8 +1201,8 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
|
|
| 1194 |
/* .seed = */ seed,
|
| 1195 |
/* .seed_cur = */ seed_cur,
|
| 1196 |
/* .rng = */ std::mt19937(seed_cur),
|
| 1197 |
-
}
|
| 1198 |
-
|
| 1199 |
}
|
| 1200 |
|
| 1201 |
// mirostat
|
|
@@ -1292,7 +1299,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
|
|
| 1292 |
|
| 1293 |
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
|
| 1294 |
auto seed_cur = get_rng_seed(seed);
|
| 1295 |
-
return
|
| 1296 |
/* .iface = */ &llama_sampler_mirostat_i,
|
| 1297 |
/* .ctx = */ new llama_sampler_mirostat {
|
| 1298 |
/* .n_vocab = */ n_vocab,
|
|
@@ -1303,8 +1310,8 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
|
|
| 1303 |
/* .m = */ m,
|
| 1304 |
/* .mu = */ 2.0f*tau,
|
| 1305 |
/* .rng = */ std::mt19937(seed_cur),
|
| 1306 |
-
}
|
| 1307 |
-
|
| 1308 |
}
|
| 1309 |
|
| 1310 |
// mirostat v2
|
|
@@ -1391,7 +1398,7 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
|
|
| 1391 |
|
| 1392 |
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
|
| 1393 |
auto seed_cur = get_rng_seed(seed);
|
| 1394 |
-
return
|
| 1395 |
/* .iface = */ &llama_sampler_mirostat_v2_i,
|
| 1396 |
/* .ctx = */ new llama_sampler_mirostat_v2 {
|
| 1397 |
/* .seed = */ seed,
|
|
@@ -1400,8 +1407,8 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau,
|
|
| 1400 |
/* .eta = */ eta,
|
| 1401 |
/* .mu = */ 2.0f*tau,
|
| 1402 |
/* .rng = */ std::mt19937(seed_cur),
|
| 1403 |
-
}
|
| 1404 |
-
|
| 1405 |
}
|
| 1406 |
|
| 1407 |
// grammar
|
|
@@ -1442,7 +1449,9 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
|
|
| 1442 |
const char ** trigger_words,
|
| 1443 |
size_t num_trigger_words,
|
| 1444 |
const llama_token * trigger_tokens,
|
| 1445 |
-
size_t num_trigger_tokens
|
|
|
|
|
|
|
| 1446 |
|
| 1447 |
static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
| 1448 |
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
|
@@ -1450,12 +1459,14 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
|
| 1450 |
return;
|
| 1451 |
}
|
| 1452 |
|
| 1453 |
-
std::vector<const char *>
|
| 1454 |
-
|
| 1455 |
-
|
|
|
|
| 1456 |
}
|
|
|
|
| 1457 |
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
|
| 1458 |
-
ctx->grammar->lazy,
|
| 1459 |
ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
|
| 1460 |
|
| 1461 |
llama_grammar_free_impl(ctx->grammar);
|
|
@@ -1465,7 +1476,8 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
|
| 1465 |
static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
|
| 1466 |
const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
|
| 1467 |
|
| 1468 |
-
auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0);
|
|
|
|
| 1469 |
|
| 1470 |
// copy the state
|
| 1471 |
{
|
|
@@ -1509,16 +1521,38 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
|
|
| 1509 |
const char ** trigger_words,
|
| 1510 |
size_t num_trigger_words,
|
| 1511 |
const llama_token * trigger_tokens,
|
| 1512 |
-
size_t num_trigger_tokens
|
|
|
|
|
|
|
| 1513 |
auto * ctx = new llama_sampler_grammar;
|
| 1514 |
|
| 1515 |
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1516 |
*ctx = {
|
| 1517 |
/* .vocab = */ vocab,
|
| 1518 |
/* .grammar_str = */ grammar_str,
|
| 1519 |
/* .grammar_root = */ grammar_root,
|
| 1520 |
-
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy,
|
| 1521 |
};
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1522 |
} else {
|
| 1523 |
*ctx = {
|
| 1524 |
/* .vocab = */ vocab,
|
|
@@ -1528,17 +1562,17 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
|
|
| 1528 |
};
|
| 1529 |
}
|
| 1530 |
|
| 1531 |
-
return
|
| 1532 |
/* .iface = */ &llama_sampler_grammar_i,
|
| 1533 |
-
/* .ctx = */ ctx
|
| 1534 |
-
|
| 1535 |
}
|
| 1536 |
|
| 1537 |
struct llama_sampler * llama_sampler_init_grammar(
|
| 1538 |
const struct llama_vocab * vocab,
|
| 1539 |
const char * grammar_str,
|
| 1540 |
const char * grammar_root) {
|
| 1541 |
-
return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0);
|
| 1542 |
}
|
| 1543 |
|
| 1544 |
struct llama_sampler * llama_sampler_init_grammar_lazy(
|
|
@@ -1549,7 +1583,18 @@ struct llama_sampler * llama_sampler_init_grammar_lazy(
|
|
| 1549 |
size_t num_trigger_words,
|
| 1550 |
const llama_token * trigger_tokens,
|
| 1551 |
size_t num_trigger_tokens) {
|
| 1552 |
-
return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1553 |
}
|
| 1554 |
|
| 1555 |
// penalties
|
|
@@ -1678,7 +1723,7 @@ struct llama_sampler * llama_sampler_init_penalties(
|
|
| 1678 |
float penalty_present) {
|
| 1679 |
penalty_last_n = std::max(penalty_last_n, 0);
|
| 1680 |
|
| 1681 |
-
return
|
| 1682 |
/* .iface = */ &llama_sampler_penalties_i,
|
| 1683 |
/* .ctx = */ new llama_sampler_penalties {
|
| 1684 |
/* .penalty_last_n = */ penalty_last_n,
|
|
@@ -1687,8 +1732,75 @@ struct llama_sampler * llama_sampler_init_penalties(
|
|
| 1687 |
/* .penalty_present = */ penalty_present,
|
| 1688 |
/* .prev = */ ring_buffer<llama_token>(penalty_last_n),
|
| 1689 |
/* .token_count = */ {},
|
| 1690 |
-
}
|
| 1691 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1692 |
}
|
| 1693 |
|
| 1694 |
// DRY
|
|
@@ -2041,7 +2153,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
|
|
| 2041 |
}
|
| 2042 |
}
|
| 2043 |
|
| 2044 |
-
return
|
| 2045 |
/* .iface = */ &llama_sampler_dry_i,
|
| 2046 |
/* .ctx = */ new llama_sampler_dry {
|
| 2047 |
/* .total_context_size = */ context_size,
|
|
@@ -2053,8 +2165,8 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
|
|
| 2053 |
/* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
|
| 2054 |
/* .dry_max_token_repeat = */ {},
|
| 2055 |
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
|
| 2056 |
-
}
|
| 2057 |
-
|
| 2058 |
}
|
| 2059 |
|
| 2060 |
// wrapper for test-sampling.cpp
|
|
@@ -2155,14 +2267,14 @@ struct llama_sampler * llama_sampler_init_logit_bias(
|
|
| 2155 |
int32_t n_vocab,
|
| 2156 |
int32_t n_logit_bias,
|
| 2157 |
const llama_logit_bias * logit_bias) {
|
| 2158 |
-
return
|
| 2159 |
/* .iface = */ &llama_sampler_logit_bias_i,
|
| 2160 |
/* .ctx = */ new llama_sampler_logit_bias {
|
| 2161 |
/* .n_vocab = */ n_vocab,
|
| 2162 |
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
|
| 2163 |
/* .to_search = */ {},
|
| 2164 |
-
}
|
| 2165 |
-
|
| 2166 |
}
|
| 2167 |
|
| 2168 |
// infill
|
|
@@ -2377,14 +2489,14 @@ static struct llama_sampler_i llama_sampler_infill_i = {
|
|
| 2377 |
};
|
| 2378 |
|
| 2379 |
struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
|
| 2380 |
-
return
|
| 2381 |
/* .iface = */ &llama_sampler_infill_i,
|
| 2382 |
/* .ctx = */ new llama_sampler_infill {
|
| 2383 |
/* .vocab = */ vocab,
|
| 2384 |
/* .buf0 = */ std::vector<char>(512),
|
| 2385 |
/* .buf1 = */ std::vector<char>(512),
|
| 2386 |
-
}
|
| 2387 |
-
|
| 2388 |
}
|
| 2389 |
|
| 2390 |
// utils
|
|
|
|
| 316 |
|
| 317 |
// llama_sampler API
|
| 318 |
|
| 319 |
+
struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
|
| 320 |
+
return new llama_sampler {
|
| 321 |
+
/* .iface = */ iface,
|
| 322 |
+
/* .ctx = */ ctx,
|
| 323 |
+
};
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
const char * llama_sampler_name(const struct llama_sampler * smpl) {
|
| 327 |
if (!smpl->iface) {
|
| 328 |
return "(null)";
|
|
|
|
| 354 |
}
|
| 355 |
|
| 356 |
if (smpl->ctx == nullptr) {
|
| 357 |
+
return llama_sampler_init(
|
| 358 |
/* .iface = */ smpl->iface,
|
| 359 |
+
/* .ctx = */ nullptr
|
| 360 |
+
);
|
| 361 |
}
|
| 362 |
|
| 363 |
GGML_ABORT("the sampler does not support cloning");
|
|
|
|
| 479 |
};
|
| 480 |
|
| 481 |
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
|
| 482 |
+
return llama_sampler_init(
|
| 483 |
/* .iface = */ &llama_sampler_chain_i,
|
| 484 |
/* .ctx = */ new llama_sampler_chain {
|
| 485 |
/* .params = */ params,
|
| 486 |
/* .samplers = */ {},
|
| 487 |
/* .t_sample_us = */ 0,
|
| 488 |
/* .n_sample = */ 0,
|
| 489 |
+
}
|
| 490 |
+
);
|
| 491 |
}
|
| 492 |
|
| 493 |
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
|
|
|
|
| 553 |
};
|
| 554 |
|
| 555 |
struct llama_sampler * llama_sampler_init_greedy() {
|
| 556 |
+
return llama_sampler_init(
|
| 557 |
/* .iface = */ &llama_sampler_greedy_i,
|
| 558 |
+
/* .ctx = */ nullptr
|
| 559 |
+
);
|
| 560 |
}
|
| 561 |
|
| 562 |
// dist
|
|
|
|
| 615 |
|
| 616 |
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
| 617 |
auto seed_cur = get_rng_seed(seed);
|
| 618 |
+
return llama_sampler_init(
|
| 619 |
/* .iface = */ &llama_sampler_dist_i,
|
| 620 |
/* .ctx = */ new llama_sampler_dist {
|
| 621 |
/* .seed = */ seed,
|
| 622 |
/* .seed_cur = */ seed_cur,
|
| 623 |
/* .rng = */ std::mt19937(seed_cur),
|
| 624 |
+
}
|
| 625 |
+
);
|
| 626 |
}
|
| 627 |
|
| 628 |
// softmax
|
|
|
|
| 645 |
};
|
| 646 |
|
| 647 |
struct llama_sampler * llama_sampler_init_softmax() {
|
| 648 |
+
return llama_sampler_init(
|
| 649 |
/* .iface = */ &llama_sampler_softmax_i,
|
| 650 |
+
/* .ctx = */ nullptr
|
| 651 |
+
);
|
| 652 |
}
|
| 653 |
|
| 654 |
// top-k
|
|
|
|
| 685 |
};
|
| 686 |
|
| 687 |
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
| 688 |
+
return llama_sampler_init(
|
| 689 |
/* .iface = */ &llama_sampler_top_k_i,
|
| 690 |
/* .ctx = */ new llama_sampler_top_k {
|
| 691 |
/* .k = */ k,
|
| 692 |
+
}
|
| 693 |
+
);
|
| 694 |
}
|
| 695 |
|
| 696 |
// top-p
|
|
|
|
| 751 |
};
|
| 752 |
|
| 753 |
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
|
| 754 |
+
return llama_sampler_init(
|
| 755 |
/* .iface = */ &llama_sampler_top_p_i,
|
| 756 |
/* .ctx = */ new llama_sampler_top_p {
|
| 757 |
/* .p = */ p,
|
| 758 |
/* .min_keep = */ min_keep,
|
| 759 |
+
}
|
| 760 |
+
);
|
| 761 |
}
|
| 762 |
|
| 763 |
// min-p
|
|
|
|
| 847 |
};
|
| 848 |
|
| 849 |
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
|
| 850 |
+
return llama_sampler_init(
|
| 851 |
/* .iface = */ &llama_sampler_min_p_i,
|
| 852 |
/* .ctx = */ new llama_sampler_min_p {
|
| 853 |
/* .p = */ p,
|
| 854 |
/* .min_keep = */ min_keep,
|
| 855 |
+
}
|
| 856 |
+
);
|
| 857 |
}
|
| 858 |
|
| 859 |
// typical
|
|
|
|
| 946 |
};
|
| 947 |
|
| 948 |
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
|
| 949 |
+
return llama_sampler_init(
|
| 950 |
/* .iface = */ &llama_sampler_typical_i,
|
| 951 |
/* .ctx = */ new llama_sampler_typical {
|
| 952 |
/* .p = */ p,
|
| 953 |
/* .min_keep = */ min_keep,
|
| 954 |
+
}
|
| 955 |
+
);
|
| 956 |
}
|
| 957 |
|
| 958 |
// temp
|
|
|
|
| 990 |
};
|
| 991 |
|
| 992 |
struct llama_sampler * llama_sampler_init_temp(float temp) {
|
| 993 |
+
return llama_sampler_init(
|
| 994 |
/* .iface = */ &llama_sampler_temp_i,
|
| 995 |
/* .ctx = */ new llama_sampler_temp {
|
| 996 |
/*.temp = */ temp,
|
| 997 |
+
}
|
| 998 |
+
);
|
| 999 |
}
|
| 1000 |
|
| 1001 |
// temp-ext
|
|
|
|
| 1100 |
};
|
| 1101 |
|
| 1102 |
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
|
| 1103 |
+
return llama_sampler_init(
|
| 1104 |
/* .iface = */ &llama_sampler_temp_ext_i,
|
| 1105 |
/* .ctx = */ new llama_sampler_temp_ext {
|
| 1106 |
/* .temp = */ temp,
|
| 1107 |
/* .delta = */ delta,
|
| 1108 |
/* .exponent = */ exponent,
|
| 1109 |
+
}
|
| 1110 |
+
);
|
| 1111 |
}
|
| 1112 |
|
| 1113 |
// xtc
|
|
|
|
| 1192 |
|
| 1193 |
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
|
| 1194 |
auto seed_cur = get_rng_seed(seed);
|
| 1195 |
+
return llama_sampler_init(
|
| 1196 |
/* .iface = */ &llama_sampler_xtc_i,
|
| 1197 |
/* .ctx = */ new llama_sampler_xtc {
|
| 1198 |
/* .probability = */ p,
|
|
|
|
| 1201 |
/* .seed = */ seed,
|
| 1202 |
/* .seed_cur = */ seed_cur,
|
| 1203 |
/* .rng = */ std::mt19937(seed_cur),
|
| 1204 |
+
}
|
| 1205 |
+
);
|
| 1206 |
}
|
| 1207 |
|
| 1208 |
// mirostat
|
|
|
|
| 1299 |
|
| 1300 |
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
|
| 1301 |
auto seed_cur = get_rng_seed(seed);
|
| 1302 |
+
return llama_sampler_init(
|
| 1303 |
/* .iface = */ &llama_sampler_mirostat_i,
|
| 1304 |
/* .ctx = */ new llama_sampler_mirostat {
|
| 1305 |
/* .n_vocab = */ n_vocab,
|
|
|
|
| 1310 |
/* .m = */ m,
|
| 1311 |
/* .mu = */ 2.0f*tau,
|
| 1312 |
/* .rng = */ std::mt19937(seed_cur),
|
| 1313 |
+
}
|
| 1314 |
+
);
|
| 1315 |
}
|
| 1316 |
|
| 1317 |
// mirostat v2
|
|
|
|
| 1398 |
|
| 1399 |
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
|
| 1400 |
auto seed_cur = get_rng_seed(seed);
|
| 1401 |
+
return llama_sampler_init(
|
| 1402 |
/* .iface = */ &llama_sampler_mirostat_v2_i,
|
| 1403 |
/* .ctx = */ new llama_sampler_mirostat_v2 {
|
| 1404 |
/* .seed = */ seed,
|
|
|
|
| 1407 |
/* .eta = */ eta,
|
| 1408 |
/* .mu = */ 2.0f*tau,
|
| 1409 |
/* .rng = */ std::mt19937(seed_cur),
|
| 1410 |
+
}
|
| 1411 |
+
);
|
| 1412 |
}
|
| 1413 |
|
| 1414 |
// grammar
|
|
|
|
| 1449 |
const char ** trigger_words,
|
| 1450 |
size_t num_trigger_words,
|
| 1451 |
const llama_token * trigger_tokens,
|
| 1452 |
+
size_t num_trigger_tokens,
|
| 1453 |
+
const char ** trigger_patterns,
|
| 1454 |
+
size_t num_trigger_patterns);
|
| 1455 |
|
| 1456 |
static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
| 1457 |
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
|
|
|
| 1459 |
return;
|
| 1460 |
}
|
| 1461 |
|
| 1462 |
+
std::vector<const char *> trigger_patterns_c;
|
| 1463 |
+
trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size());
|
| 1464 |
+
for (auto & trigger_pattern : ctx->grammar->trigger_patterns) {
|
| 1465 |
+
trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
|
| 1466 |
}
|
| 1467 |
+
|
| 1468 |
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
|
| 1469 |
+
ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
|
| 1470 |
ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
|
| 1471 |
|
| 1472 |
llama_grammar_free_impl(ctx->grammar);
|
|
|
|
| 1476 |
static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
|
| 1477 |
const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
|
| 1478 |
|
| 1479 |
+
auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0, nullptr, 0);
|
| 1480 |
+
GGML_ASSERT(result);
|
| 1481 |
|
| 1482 |
// copy the state
|
| 1483 |
{
|
|
|
|
| 1521 |
const char ** trigger_words,
|
| 1522 |
size_t num_trigger_words,
|
| 1523 |
const llama_token * trigger_tokens,
|
| 1524 |
+
size_t num_trigger_tokens,
|
| 1525 |
+
const char ** trigger_patterns,
|
| 1526 |
+
size_t num_trigger_patterns) {
|
| 1527 |
auto * ctx = new llama_sampler_grammar;
|
| 1528 |
|
| 1529 |
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
| 1530 |
+
// TODO: remove trigger_words support.
|
| 1531 |
+
if (trigger_words != nullptr && num_trigger_words > 0) {
|
| 1532 |
+
GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
|
| 1533 |
+
std::string trigger_pattern("[\\s\\S]*?(");
|
| 1534 |
+
for (size_t i = 0; i < num_trigger_words; ++i) {
|
| 1535 |
+
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
| 1536 |
+
if (i > 0) {
|
| 1537 |
+
trigger_pattern += "|";
|
| 1538 |
+
}
|
| 1539 |
+
trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
|
| 1540 |
+
}
|
| 1541 |
+
trigger_pattern += ")[\\s\\S]*";
|
| 1542 |
+
auto trigger_pattern_c = trigger_pattern.c_str();
|
| 1543 |
+
trigger_patterns = &trigger_pattern_c;
|
| 1544 |
+
num_trigger_patterns = 1;
|
| 1545 |
+
}
|
| 1546 |
*ctx = {
|
| 1547 |
/* .vocab = */ vocab,
|
| 1548 |
/* .grammar_str = */ grammar_str,
|
| 1549 |
/* .grammar_root = */ grammar_root,
|
| 1550 |
+
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
|
| 1551 |
};
|
| 1552 |
+
if (!ctx->grammar) {
|
| 1553 |
+
delete ctx;
|
| 1554 |
+
return nullptr;
|
| 1555 |
+
}
|
| 1556 |
} else {
|
| 1557 |
*ctx = {
|
| 1558 |
/* .vocab = */ vocab,
|
|
|
|
| 1562 |
};
|
| 1563 |
}
|
| 1564 |
|
| 1565 |
+
return llama_sampler_init(
|
| 1566 |
/* .iface = */ &llama_sampler_grammar_i,
|
| 1567 |
+
/* .ctx = */ ctx
|
| 1568 |
+
);
|
| 1569 |
}
|
| 1570 |
|
| 1571 |
struct llama_sampler * llama_sampler_init_grammar(
|
| 1572 |
const struct llama_vocab * vocab,
|
| 1573 |
const char * grammar_str,
|
| 1574 |
const char * grammar_root) {
|
| 1575 |
+
return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0, nullptr, 0);
|
| 1576 |
}
|
| 1577 |
|
| 1578 |
struct llama_sampler * llama_sampler_init_grammar_lazy(
|
|
|
|
| 1583 |
size_t num_trigger_words,
|
| 1584 |
const llama_token * trigger_tokens,
|
| 1585 |
size_t num_trigger_tokens) {
|
| 1586 |
+
return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens, nullptr, 0);
|
| 1587 |
+
}
|
| 1588 |
+
|
| 1589 |
+
struct llama_sampler * llama_sampler_init_grammar_lazy_patterns(
|
| 1590 |
+
const struct llama_vocab * vocab,
|
| 1591 |
+
const char * grammar_str,
|
| 1592 |
+
const char * grammar_root,
|
| 1593 |
+
const char ** trigger_patterns,
|
| 1594 |
+
size_t num_trigger_patterns,
|
| 1595 |
+
const llama_token * trigger_tokens,
|
| 1596 |
+
size_t num_trigger_tokens) {
|
| 1597 |
+
return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, nullptr, 0, trigger_tokens, num_trigger_tokens, trigger_patterns, num_trigger_patterns);
|
| 1598 |
}
|
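The deprecated trigger_words path above is folded into a single trigger pattern: each word is regex-escaped and the words are joined with `|` inside `[\s\S]*?( ... )[\s\S]*`. A standalone sketch of that conversion follows; the two trigger words are made-up examples, and the escaping logic mirrors the code above.

    #include <cstdio>
    #include <regex>
    #include <string>

    int main() {
        const char * trigger_words[] = { "<tool_call>", "```json" }; // hypothetical triggers
        const size_t num_trigger_words = 2;

        static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");

        std::string trigger_pattern("[\\s\\S]*?(");
        for (size_t i = 0; i < num_trigger_words; ++i) {
            if (i > 0) {
                trigger_pattern += "|";
            }
            // escape regex metacharacters so each word is matched literally
            trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
        }
        trigger_pattern += ")[\\s\\S]*";

        // prints: [\s\S]*?(<tool_call>|```json)[\s\S]*
        printf("%s\n", trigger_pattern.c_str());
        return 0;
    }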
| 1599 |
|
| 1600 |
// penalties
|
|
|
|
| 1723 |
float penalty_present) {
|
| 1724 |
penalty_last_n = std::max(penalty_last_n, 0);
|
| 1725 |
|
| 1726 |
+
return llama_sampler_init(
|
| 1727 |
/* .iface = */ &llama_sampler_penalties_i,
|
| 1728 |
/* .ctx = */ new llama_sampler_penalties {
|
| 1729 |
/* .penalty_last_n = */ penalty_last_n,
|
|
|
|
| 1732 |
/* .penalty_present = */ penalty_present,
|
| 1733 |
/* .prev = */ ring_buffer<llama_token>(penalty_last_n),
|
| 1734 |
/* .token_count = */ {},
|
| 1735 |
+
}
|
| 1736 |
+
);
|
| 1737 |
+
}
|
| 1738 |
+
|
| 1739 |
+
// top-n-sigma
|
| 1740 |
+
|
| 1741 |
+
struct llama_sampler_top_n_sigma {
|
| 1742 |
+
const float n;
|
| 1743 |
+
};
|
| 1744 |
+
|
| 1745 |
+
static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
|
| 1746 |
+
return "top-n-sigma";
|
| 1747 |
+
}
|
| 1748 |
+
|
| 1749 |
+
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
| 1750 |
+
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
|
| 1751 |
+
|
| 1752 |
+
// find max logit and calculate mean
|
| 1753 |
+
float max = cur_p->data[0].logit;
|
| 1754 |
+
float logits_sum = 0;
|
| 1755 |
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
| 1756 |
+
if (cur_p->data[i].logit > max) {
|
| 1757 |
+
max = cur_p->data[i].logit;
|
| 1758 |
+
}
|
| 1759 |
+
logits_sum += cur_p->data[i].logit;
|
| 1760 |
+
}
|
| 1761 |
+
float mean = logits_sum/cur_p->size;
|
| 1762 |
+
|
| 1763 |
+
// calculate standard deviation
|
| 1764 |
+
float acc = 0;
|
| 1765 |
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
| 1766 |
+
acc += pow(cur_p->data[i].logit - mean, 2);
|
| 1767 |
+
}
|
| 1768 |
+
float std = sqrt(acc/cur_p->size);
|
| 1769 |
+
|
| 1770 |
+
//apply mask
|
| 1771 |
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
| 1772 |
+
if (cur_p->data[i].logit < max - (ctx->n * std)) {
|
| 1773 |
+
cur_p->data[i].logit = -INFINITY;
|
| 1774 |
+
}
|
| 1775 |
+
}
|
| 1776 |
+
llama_sampler_softmax_impl(cur_p);
|
| 1777 |
+
}
|
| 1778 |
+
|
| 1779 |
+
static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
|
| 1780 |
+
const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
|
| 1781 |
+
return llama_sampler_init_top_n_sigma(ctx->n);
|
| 1782 |
+
}
|
| 1783 |
+
|
| 1784 |
+
static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
|
| 1785 |
+
delete (llama_sampler_top_n_sigma *) smpl->ctx;
|
| 1786 |
+
}
|
| 1787 |
+
|
| 1788 |
+
static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
|
| 1789 |
+
/* .name = */ llama_sampler_top_n_sigma_name,
|
| 1790 |
+
/* .accept = */ nullptr,
|
| 1791 |
+
/* .apply = */ llama_sampler_top_n_sigma_apply,
|
| 1792 |
+
/* .reset = */ nullptr,
|
| 1793 |
+
/* .clone = */ llama_sampler_top_n_sigma_clone,
|
| 1794 |
+
/* .free = */ llama_sampler_top_n_sigma_free,
|
| 1795 |
+
};
|
| 1796 |
+
|
| 1797 |
+
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
|
| 1798 |
+
return llama_sampler_init(
|
| 1799 |
+
/* .iface = */ &llama_sampler_top_n_sigma_i,
|
| 1800 |
+
/* .ctx = */ new llama_sampler_top_n_sigma {
|
| 1801 |
+
/* .n = */ n,
|
| 1802 |
+
}
|
| 1803 |
+
);
|
| 1804 |
}
|
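llama_sampler_top_n_sigma_apply keeps only candidates whose logit is at least max − n·σ (with σ computed over the current candidate list) and then renormalizes with softmax. A hedged usage sketch; the chain helpers (llama_sampler_chain_init/add, llama_sampler_init_dist) are existing llama.h APIs, and the value 1.0f is an arbitrary choice.

    // Sketch: use the new top-n-sigma sampler in front of a distribution sampler.
    struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    // mask out candidates with logit < max_logit - 1.0f * stddev(logits)
    llama_sampler_chain_add(chain, llama_sampler_init_top_n_sigma(1.0f));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    // ... llama_token tok = llama_sampler_sample(chain, ctx, -1);

    llama_sampler_free(chain);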
| 1805 |
|
| 1806 |
// DRY
|
|
|
|
| 2153 |
}
|
| 2154 |
}
|
| 2155 |
|
| 2156 |
+
return llama_sampler_init(
|
| 2157 |
/* .iface = */ &llama_sampler_dry_i,
|
| 2158 |
/* .ctx = */ new llama_sampler_dry {
|
| 2159 |
/* .total_context_size = */ context_size,
|
|
|
|
| 2165 |
/* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
|
| 2166 |
/* .dry_max_token_repeat = */ {},
|
| 2167 |
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
|
| 2168 |
+
}
|
| 2169 |
+
);
|
| 2170 |
}
|
| 2171 |
|
| 2172 |
// wrapper for test-sampling.cpp
|
|
|
|
| 2267 |
int32_t n_vocab,
|
| 2268 |
int32_t n_logit_bias,
|
| 2269 |
const llama_logit_bias * logit_bias) {
|
| 2270 |
+
return llama_sampler_init(
|
| 2271 |
/* .iface = */ &llama_sampler_logit_bias_i,
|
| 2272 |
/* .ctx = */ new llama_sampler_logit_bias {
|
| 2273 |
/* .n_vocab = */ n_vocab,
|
| 2274 |
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
|
| 2275 |
/* .to_search = */ {},
|
| 2276 |
+
}
|
| 2277 |
+
);
|
| 2278 |
}
|
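The logit-bias sampler copies the caller's (token, bias) pairs into its context and applies them before sampling. A small hedged sketch that suppresses a single token; the token id 1234 is made up, and `vocab` is assumed to come from llama_model_get_vocab().

    // Sketch: forbid one token by assigning it a -INFINITY bias (needs <cmath> for INFINITY).
    llama_logit_bias banned[1] = {
        { /* .token = */ 1234, /* .bias = */ -INFINITY }, // hypothetical token id
    };

    struct llama_sampler * smpl = llama_sampler_init_logit_bias(
            llama_vocab_n_tokens(vocab), /* n_logit_bias = */ 1, banned);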
| 2279 |
|
| 2280 |
// infill
|
|
|
|
| 2489 |
};
|
| 2490 |
|
| 2491 |
struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
|
| 2492 |
+
return llama_sampler_init(
|
| 2493 |
/* .iface = */ &llama_sampler_infill_i,
|
| 2494 |
/* .ctx = */ new llama_sampler_infill {
|
| 2495 |
/* .vocab = */ vocab,
|
| 2496 |
/* .buf0 = */ std::vector<char>(512),
|
| 2497 |
/* .buf1 = */ std::vector<char>(512),
|
| 2498 |
+
}
|
| 2499 |
+
);
|
| 2500 |
}
|
| 2501 |
|
| 2502 |
// utils
|
examples/talk-llama/llama-vocab.cpp
CHANGED
|
@@ -16,6 +16,7 @@
|
|
| 16 |
#include <queue>
|
| 17 |
#include <set>
|
| 18 |
#include <unordered_map>
|
|
|
|
| 19 |
|
| 20 |
//
|
| 21 |
// helpers
|
|
@@ -341,6 +342,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
|
| 341 |
case LLAMA_VOCAB_PRE_TYPE_MPT:
|
| 342 |
case LLAMA_VOCAB_PRE_TYPE_OLMO:
|
| 343 |
case LLAMA_VOCAB_PRE_TYPE_JAIS:
|
|
|
|
| 344 |
regex_exprs = {
|
| 345 |
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
| 346 |
};
|
|
@@ -392,6 +394,27 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
|
| 392 |
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
| 393 |
};
|
| 394 |
break;
|
|
|
|
|
| 395 |
default:
|
| 396 |
// default regex for BPE tokenization pre-processing
|
| 397 |
regex_exprs = {
|
|
@@ -1483,7 +1506,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|
| 1483 |
tokenizer_pre == "llama3" ||
|
| 1484 |
tokenizer_pre == "llama-v3" ||
|
| 1485 |
tokenizer_pre == "llama-bpe"||
|
| 1486 |
-
tokenizer_pre == "falcon3"
|
|
|
|
| 1487 |
pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
|
| 1488 |
ignore_merges = true;
|
| 1489 |
add_bos = true;
|
|
@@ -1549,6 +1573,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|
| 1549 |
pre_type = LLAMA_VOCAB_PRE_TYPE_PORO;
|
| 1550 |
clean_spaces = false;
|
| 1551 |
} else if (
|
|
|
|
| 1552 |
tokenizer_pre == "chatglm-bpe") {
|
| 1553 |
pre_type = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
|
| 1554 |
special_bos_id = LLAMA_TOKEN_NULL;
|
|
@@ -1592,6 +1617,23 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|
| 1592 |
} else if (
|
| 1593 |
tokenizer_pre == "megrez") {
|
| 1594 |
pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
|
|
|
|
|
| 1595 |
} else {
|
| 1596 |
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
|
| 1597 |
}
|
|
@@ -1769,6 +1811,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|
| 1769 |
|| t.first == "<end_of_turn>"
|
| 1770 |
|| t.first == "<|endoftext|>"
|
| 1771 |
|| t.first == "<EOT>"
|
|
|
|
| 1772 |
|| t.first == "<|end▁of▁sentence|>" // DeepSeek
|
| 1773 |
) {
|
| 1774 |
special_eot_id = t.second;
|
|
@@ -1799,8 +1842,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|
| 1799 |
if (false
|
| 1800 |
|| t.first == "<|fim_prefix|>" // Qwen
|
| 1801 |
|| t.first == "<fim-prefix>"
|
|
|
|
| 1802 |
|| t.first == "<|fim▁begin|>" // DeepSeek
|
| 1803 |
|| t.first == "<PRE>"
|
|
|
|
| 1804 |
) {
|
| 1805 |
special_fim_pre_id = t.second;
|
| 1806 |
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
|
@@ -1816,8 +1861,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|
| 1816 |
if (false
|
| 1817 |
|| t.first == "<|fim_suffix|>" // Qwen
|
| 1818 |
|| t.first == "<fim-suffix>"
|
|
|
|
| 1819 |
|| t.first == "<|fim▁hole|>" // DeepSeek
|
| 1820 |
|| t.first == "<SUF>"
|
|
|
|
| 1821 |
) {
|
| 1822 |
special_fim_suf_id = t.second;
|
| 1823 |
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
|
@@ -1833,8 +1880,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|
| 1833 |
if (false
|
| 1834 |
|| t.first == "<|fim_middle|>" // Qwen
|
| 1835 |
|| t.first == "<fim-middle>"
|
|
|
|
| 1836 |
|| t.first == "<|fim▁end|>" // DeepSeek
|
| 1837 |
|| t.first == "<MID>"
|
|
|
|
| 1838 |
) {
|
| 1839 |
special_fim_mid_id = t.second;
|
| 1840 |
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
|
@@ -1850,6 +1899,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|
| 1850 |
if (false
|
| 1851 |
|| t.first == "<|fim_pad|>" // Qwen
|
| 1852 |
|| t.first == "<fim-pad>"
|
|
|
|
| 1853 |
|| t.first == "<PAD>"
|
| 1854 |
) {
|
| 1855 |
special_fim_pad_id = t.second;
|
|
@@ -1868,6 +1918,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|
| 1868 |
|| t.first == "<|repo_name|>"
|
| 1869 |
|| t.first == "<fim-repo>"
|
| 1870 |
|| t.first == "<REPO>"
|
|
|
|
| 1871 |
) {
|
| 1872 |
special_fim_rep_id = t.second;
|
| 1873 |
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
|
@@ -1919,6 +1970,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|
| 1919 |
|| t.first == "<|endoftext|>"
|
| 1920 |
|| t.first == "<|eom_id|>"
|
| 1921 |
|| t.first == "<EOT>"
|
|
|
|
| 1922 |
) {
|
| 1923 |
special_eog_ids.insert(t.second);
|
| 1924 |
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
|
@@ -2177,14 +2229,12 @@ void llama_vocab::impl::tokenizer_st_partition(std::forward_list<fragment_buffer
|
|
| 2177 |
// find the first occurrence of a given special token in this fragment
|
| 2178 |
// passing offset argument only limit the "search area" but match coordinates
|
| 2179 |
// are still relative to the source full raw_text
|
| 2180 |
-
|
|
|
|
| 2181 |
|
| 2182 |
// no occurrences found, stop processing this fragment for a given special token
|
| 2183 |
if (match == std::string::npos) break;
|
| 2184 |
|
| 2185 |
-
// check if match is within bounds of offset <-> length
|
| 2186 |
-
if (match + text.length() > raw_text_base_offset + raw_text_base_length) break;
|
| 2187 |
-
|
| 2188 |
#ifdef PRETOKENIZERDEBUG
|
| 2189 |
LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
|
| 2190 |
#endif
|
|
|
|
| 16 |
#include <queue>
|
| 17 |
#include <set>
|
| 18 |
#include <unordered_map>
|
| 19 |
+
#include <cctype>
|
| 20 |
|
| 21 |
//
|
| 22 |
// helpers
|
|
|
|
| 342 |
case LLAMA_VOCAB_PRE_TYPE_MPT:
|
| 343 |
case LLAMA_VOCAB_PRE_TYPE_OLMO:
|
| 344 |
case LLAMA_VOCAB_PRE_TYPE_JAIS:
|
| 345 |
+
case LLAMA_VOCAB_PRE_TYPE_TRILLION:
|
| 346 |
regex_exprs = {
|
| 347 |
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
| 348 |
};
|
|
|
|
| 394 |
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
| 395 |
};
|
| 396 |
break;
|
| 397 |
+
case LLAMA_VOCAB_PRE_TYPE_GPT4O:
|
| 398 |
+
regex_exprs = {
|
| 399 |
+
// original regex from tokenizer.json
|
| 400 |
+
// "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
| 401 |
+
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
| 402 |
+
};
|
| 403 |
+
break;
|
| 404 |
+
case LLAMA_VOCAB_PRE_TYPE_SUPERBPE:
|
| 405 |
+
regex_exprs = {
|
| 406 |
+
"\\p{N}+",
|
| 407 |
+
"(?=(\\d{3})+(?!\\d))",
|
| 408 |
+
};
|
| 409 |
+
break;
|
| 410 |
+
case LLAMA_VOCAB_PRE_TYPE_BAILINGMOE:
|
| 411 |
+
regex_exprs = {
|
| 412 |
+
// original regex from tokenizer.json
|
| 413 |
+
// "'(?i:[sdmt]|ll|ve|re)|[^\\r\\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]++[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+"
|
| 414 |
+
// FIXME? Changed possessive quantifiers (?+ and ++) to greedy to avoid errors and imatrix hanging (tried atomic grouping but it's not supported?)
|
| 415 |
+
"'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
|
| 416 |
+
};
|
| 417 |
+
break;
|
| 418 |
default:
|
| 419 |
// default regex for BPE tokenization pre-processing
|
| 420 |
regex_exprs = {
|
|
|
|
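For reference, the SUPERBPE pre-tokenizer added above pairs a digit-run expression (`\p{N}+`) with a zero-width lookahead that marks break points so long numbers are grouped in threes from the right. The snippet below only demonstrates where that lookahead matches inside a digit string; it says nothing about the surrounding unicode_regex_split plumbing.

    #include <cstdio>
    #include <regex>
    #include <string>

    int main() {
        const std::string digits = "1234567";
        const std::regex  split_at("(?=(\\d{3})+(?!\\d))");

        // zero-width matches at offsets 1 and 4 -> groups "1" | "234" | "567"
        for (auto it = std::sregex_iterator(digits.begin(), digits.end(), split_at);
             it != std::sregex_iterator(); ++it) {
            printf("break at offset %zu\n", (size_t) it->position(0));
        }
        return 0;
    }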
| 1506 |
tokenizer_pre == "llama3" ||
|
| 1507 |
tokenizer_pre == "llama-v3" ||
|
| 1508 |
tokenizer_pre == "llama-bpe"||
|
| 1509 |
+
tokenizer_pre == "falcon3" ||
|
| 1510 |
+
tokenizer_pre == "pixtral") {
|
| 1511 |
pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
|
| 1512 |
ignore_merges = true;
|
| 1513 |
add_bos = true;
|
|
|
|
| 1573 |
pre_type = LLAMA_VOCAB_PRE_TYPE_PORO;
|
| 1574 |
clean_spaces = false;
|
| 1575 |
} else if (
|
| 1576 |
+
tokenizer_pre == "glm4" ||
|
| 1577 |
tokenizer_pre == "chatglm-bpe") {
|
| 1578 |
pre_type = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
|
| 1579 |
special_bos_id = LLAMA_TOKEN_NULL;
|
|
|
|
| 1617 |
} else if (
|
| 1618 |
tokenizer_pre == "megrez") {
|
| 1619 |
pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
|
| 1620 |
+
} else if (
|
| 1621 |
+
tokenizer_pre == "gpt-4o" ||
|
| 1622 |
+
tokenizer_pre == "llama4") {
|
| 1623 |
+
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
|
| 1624 |
+
clean_spaces = false;
|
| 1625 |
+
} else if (
|
| 1626 |
+
tokenizer_pre == "superbpe") {
|
| 1627 |
+
pre_type = LLAMA_VOCAB_PRE_TYPE_SUPERBPE;
|
| 1628 |
+
clean_spaces = false;
|
| 1629 |
+
} else if (
|
| 1630 |
+
tokenizer_pre == "trillion") {
|
| 1631 |
+
pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION;
|
| 1632 |
+
clean_spaces = false;
|
| 1633 |
+
} else if (
|
| 1634 |
+
tokenizer_pre == "bailingmoe") {
|
| 1635 |
+
pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
|
| 1636 |
+
clean_spaces = false;
|
| 1637 |
} else {
|
| 1638 |
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
|
| 1639 |
}
|
|
|
|
| 1811 |
|| t.first == "<end_of_turn>"
|
| 1812 |
|| t.first == "<|endoftext|>"
|
| 1813 |
|| t.first == "<EOT>"
|
| 1814 |
+
|| t.first == "_<EOT>"
|
| 1815 |
|| t.first == "<|end▁of▁sentence|>" // DeepSeek
|
| 1816 |
) {
|
| 1817 |
special_eot_id = t.second;
|
|
|
|
| 1842 |
if (false
|
| 1843 |
|| t.first == "<|fim_prefix|>" // Qwen
|
| 1844 |
|| t.first == "<fim-prefix>"
|
| 1845 |
+
|| t.first == "<fim_prefix>" // Granite
|
| 1846 |
|| t.first == "<|fim▁begin|>" // DeepSeek
|
| 1847 |
|| t.first == "<PRE>"
|
| 1848 |
+
|| t.first == "▁<PRE>" // CodeLlama
|
| 1849 |
) {
|
| 1850 |
special_fim_pre_id = t.second;
|
| 1851 |
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
|
|
|
| 1861 |
if (false
|
| 1862 |
|| t.first == "<|fim_suffix|>" // Qwen
|
| 1863 |
|| t.first == "<fim-suffix>"
|
| 1864 |
+
|| t.first == "<fim_suffix>" // Granite
|
| 1865 |
|| t.first == "<|fim▁hole|>" // DeepSeek
|
| 1866 |
|| t.first == "<SUF>"
|
| 1867 |
+
|| t.first == "▁<SUF>" // CodeLlama
|
| 1868 |
) {
|
| 1869 |
special_fim_suf_id = t.second;
|
| 1870 |
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
|
|
|
| 1880 |
if (false
|
| 1881 |
|| t.first == "<|fim_middle|>" // Qwen
|
| 1882 |
|| t.first == "<fim-middle>"
|
| 1883 |
+
|| t.first == "<fim_middle>" // Granite
|
| 1884 |
|| t.first == "<|fim▁end|>" // DeepSeek
|
| 1885 |
|| t.first == "<MID>"
|
| 1886 |
+
|| t.first == "▁<MID>" // CodeLlama
|
| 1887 |
) {
|
| 1888 |
special_fim_mid_id = t.second;
|
| 1889 |
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
|
|
|
| 1899 |
if (false
|
| 1900 |
|| t.first == "<|fim_pad|>" // Qwen
|
| 1901 |
|| t.first == "<fim-pad>"
|
| 1902 |
+
|| t.first == "<fim_pad>" // Granite
|
| 1903 |
|| t.first == "<PAD>"
|
| 1904 |
) {
|
| 1905 |
special_fim_pad_id = t.second;
|
|
|
|
| 1918 |
|| t.first == "<|repo_name|>"
|
| 1919 |
|| t.first == "<fim-repo>"
|
| 1920 |
|| t.first == "<REPO>"
|
| 1921 |
+
|| t.first == "<reponame>" // Granite
|
| 1922 |
) {
|
| 1923 |
special_fim_rep_id = t.second;
|
| 1924 |
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
|
|
|
| 1970 |
|| t.first == "<|endoftext|>"
|
| 1971 |
|| t.first == "<|eom_id|>"
|
| 1972 |
|| t.first == "<EOT>"
|
| 1973 |
+
|| t.first == "_<EOT>"
|
| 1974 |
) {
|
| 1975 |
special_eog_ids.insert(t.second);
|
| 1976 |
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
|
|
|
| 2229 |
// find the first occurrence of a given special token in this fragment
|
| 2230 |
// passing offset argument only limit the "search area" but match coordinates
|
| 2231 |
// are still relative to the source full raw_text
|
| 2232 |
+
// string_view begins at pos 0 for the same reason
|
| 2233 |
+
auto match = std::string_view(raw_text.data(), raw_text_base_offset + raw_text_base_length).find(text, raw_text_base_offset);
|
| 2234 |
|
| 2235 |
// no occurrences found, stop processing this fragment for a given special token
|
| 2236 |
if (match == std::string::npos) break;
|
| 2237 |
|
|
|
|
|
|
|
|
|
|
| 2238 |
#ifdef PRETOKENIZERDEBUG
|
| 2239 |
LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
|
| 2240 |
#endif
|
examples/talk-llama/llama.cpp
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/talk-llama/llama.h
CHANGED
|
@@ -60,6 +60,7 @@ extern "C" {
|
|
| 60 |
struct llama_model;
|
| 61 |
struct llama_context;
|
| 62 |
struct llama_sampler;
|
|
|
|
| 63 |
|
| 64 |
typedef int32_t llama_pos;
|
| 65 |
typedef int32_t llama_token;
|
|
@@ -105,6 +106,12 @@ extern "C" {
|
|
| 105 |
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
|
| 106 |
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
|
| 107 |
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
|
|
|
|
|
|
| 108 |
};
|
| 109 |
|
| 110 |
enum llama_rope_type {
|
|
@@ -213,7 +220,7 @@ extern "C" {
|
|
| 213 |
LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported
|
| 214 |
};
|
| 215 |
|
| 216 |
-
// TODO: simplify (https://github.com/
|
| 217 |
typedef struct llama_token_data {
|
| 218 |
llama_token id; // token id
|
| 219 |
float logit; // log-odds of the token
|
|
@@ -275,10 +282,18 @@ extern "C" {
|
|
| 275 |
};
|
| 276 |
};
|
| 277 |
|
|
|
|
|
|
|
| 278 |
struct llama_model_params {
|
| 279 |
// NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
|
| 280 |
ggml_backend_dev_t * devices;
|
| 281 |
|
|
|
|
|
|
|
|
|
|
| 282 |
int32_t n_gpu_layers; // number of layers to store in VRAM
|
| 283 |
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
|
| 284 |
|
|
@@ -307,7 +322,7 @@ extern "C" {
|
|
| 307 |
};
|
| 308 |
|
| 309 |
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
|
| 310 |
-
// https://github.com/
|
| 311 |
struct llama_context_params {
|
| 312 |
uint32_t n_ctx; // text context, 0 = from model
|
| 313 |
uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
|
|
@@ -320,7 +335,7 @@ extern "C" {
|
|
| 320 |
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
|
| 321 |
enum llama_attention_type attention_type; // attention type to use for embeddings
|
| 322 |
|
| 323 |
-
// ref: https://github.com/
|
| 324 |
float rope_freq_base; // RoPE base frequency, 0 = from model
|
| 325 |
float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
|
| 326 |
float yarn_ext_factor; // YaRN extrapolation mix factor, negative = from model
|
|
@@ -353,17 +368,18 @@ extern "C" {
|
|
| 353 |
|
| 354 |
// model quantization parameters
|
| 355 |
typedef struct llama_model_quantize_params {
|
| 356 |
-
int32_t nthread;
|
| 357 |
-
enum llama_ftype ftype;
|
| 358 |
-
enum ggml_type output_tensor_type;
|
| 359 |
-
enum ggml_type token_embedding_type;
|
| 360 |
-
bool allow_requantize;
|
| 361 |
-
bool quantize_output_tensor;
|
| 362 |
-
bool only_copy;
|
| 363 |
-
bool pure;
|
| 364 |
-
bool keep_split;
|
| 365 |
-
void * imatrix;
|
| 366 |
-
void * kv_overrides;
|
|
|
|
| 367 |
} llama_model_quantize_params;
|
| 368 |
|
| 369 |
typedef struct llama_logit_bias {
|
|
@@ -385,7 +401,7 @@ extern "C" {
|
|
| 385 |
struct llama_adapter_lora;
|
| 386 |
|
| 387 |
// Helpers for getting default parameters
|
| 388 |
-
// TODO: update API to start accepting pointers to params structs (https://github.com/
|
| 389 |
LLAMA_API struct llama_model_params llama_model_default_params(void);
|
| 390 |
LLAMA_API struct llama_context_params llama_context_default_params(void);
|
| 391 |
LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void);
|
|
@@ -468,7 +484,8 @@ extern "C" {
|
|
| 468 |
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
|
| 469 |
|
| 470 |
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
|
| 471 |
-
LLAMA_API
|
|
|
|
| 472 |
|
| 473 |
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
|
| 474 |
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
|
|
@@ -477,6 +494,7 @@ extern "C" {
|
|
| 477 |
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
|
| 478 |
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
|
| 479 |
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
|
|
|
|
| 480 |
|
| 481 |
// Get the model's RoPE frequency scaling factor
|
| 482 |
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
|
|
@@ -584,7 +602,7 @@ extern "C" {
|
|
| 584 |
// KV cache
|
| 585 |
//
|
| 586 |
|
| 587 |
-
// TODO:
|
| 588 |
|
| 589 |
// Information associated with an individual cell in the KV cache view.
|
| 590 |
struct llama_kv_cache_view_cell {
|
|
@@ -639,13 +657,19 @@ extern "C" {
|
|
| 639 |
|
| 640 |
// Returns the number of tokens in the KV cache (slow, use only for debug)
|
| 641 |
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
|
| 642 |
-
LLAMA_API int32_t
|
|
|
|
|
| 643 |
|
| 644 |
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
| 645 |
-
LLAMA_API int32_t
|
|
|
|
|
| 646 |
|
| 647 |
// Clear the KV cache - both cell info is erased and KV data is zeroed
|
| 648 |
-
LLAMA_API void
|
| 649 |
struct llama_context * ctx);
|
| 650 |
|
| 651 |
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
|
@@ -653,7 +677,7 @@ extern "C" {
|
|
| 653 |
// seq_id < 0 : match any sequence
|
| 654 |
// p0 < 0 : [0, p1]
|
| 655 |
// p1 < 0 : [p0, inf)
|
| 656 |
-
LLAMA_API bool
|
| 657 |
struct llama_context * ctx,
|
| 658 |
llama_seq_id seq_id,
|
| 659 |
llama_pos p0,
|
|
@@ -663,7 +687,7 @@ extern "C" {
|
|
| 663 |
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
|
| 664 |
// p0 < 0 : [0, p1]
|
| 665 |
// p1 < 0 : [p0, inf)
|
| 666 |
-
LLAMA_API void
|
| 667 |
struct llama_context * ctx,
|
| 668 |
llama_seq_id seq_id_src,
|
| 669 |
llama_seq_id seq_id_dst,
|
|
@@ -671,17 +695,17 @@ extern "C" {
|
|
| 671 |
llama_pos p1);
|
| 672 |
|
| 673 |
// Removes all tokens that do not belong to the specified sequence
|
| 674 |
-
LLAMA_API void
|
| 675 |
struct llama_context * ctx,
|
| 676 |
llama_seq_id seq_id);
|
| 677 |
|
| 678 |
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
| 679 |
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
| 680 |
// - lazily on next llama_decode()
|
| 681 |
-
// - explicitly with
|
| 682 |
// p0 < 0 : [0, p1]
|
| 683 |
// p1 < 0 : [p0, inf)
|
| 684 |
-
LLAMA_API void
|
| 685 |
struct llama_context * ctx,
|
| 686 |
llama_seq_id seq_id,
|
| 687 |
llama_pos p0,
|
|
@@ -691,10 +715,10 @@ extern "C" {
|
|
| 691 |
// Integer division of the positions by factor of `d > 1`
|
| 692 |
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
| 693 |
// - lazily on next llama_decode()
|
| 694 |
-
// - explicitly with
|
| 695 |
// p0 < 0 : [0, p1]
|
| 696 |
// p1 < 0 : [p0, inf)
|
| 697 |
-
LLAMA_API void
|
| 698 |
struct llama_context * ctx,
|
| 699 |
llama_seq_id seq_id,
|
| 700 |
llama_pos p0,
|
|
@@ -702,24 +726,76 @@ extern "C" {
|
|
| 702 |
int d);
|
| 703 |
|
| 704 |
// Returns the largest position present in the KV cache for the specified sequence
|
| 705 |
-
LLAMA_API llama_pos
|
| 706 |
struct llama_context * ctx,
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
// TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache
|
| 710 |
-
// how to avoid this?
|
| 711 |
|
| 712 |
// Defragment the KV cache
|
| 713 |
// This will be applied:
|
| 714 |
// - lazily on next llama_decode()
|
| 715 |
-
// - explicitly with
|
| 716 |
-
LLAMA_API void
|
|
|
|
|
| 717 |
|
| 718 |
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
|
| 719 |
-
LLAMA_API void
|
|
|
|
|
| 720 |
|
| 721 |
-
// Check if the context supports KV cache shifting
|
| 722 |
-
LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx);
|
| 723 |
|
| 724 |
//
|
| 725 |
// State / sessions
|
|
@@ -883,6 +959,10 @@ extern "C" {
|
|
| 883 |
// If set to true, the model will only attend to the past tokens
|
| 884 |
LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
|
| 885 |
|
|
|
|
|
|
| 886 |
// Set abort callback
|
| 887 |
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
|
| 888 |
|
|
@@ -1040,7 +1120,7 @@ extern "C" {
|
|
| 1040 |
|
| 1041 |
/// Apply chat template. Inspired by hf apply_chat_template() on python.
|
| 1042 |
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
|
| 1043 |
-
/// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/
|
| 1044 |
/// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead.
|
| 1045 |
/// @param chat Pointer to a list of multiple llama_chat_message
|
| 1046 |
/// @param n_msg Number of llama_chat_message in this chat
|
|
@@ -1114,11 +1194,12 @@ extern "C" {
|
|
| 1114 |
};
|
| 1115 |
|
| 1116 |
struct llama_sampler {
|
| 1117 |
-
struct llama_sampler_i
|
| 1118 |
-
llama_sampler_context_t
|
| 1119 |
};
|
| 1120 |
|
| 1121 |
// mirror of llama_sampler_i:
|
|
|
|
| 1122 |
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
|
| 1123 |
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
|
| 1124 |
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
|
|
@@ -1148,7 +1229,7 @@ extern "C" {
|
|
| 1148 |
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
| 1149 |
/// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
|
| 1150 |
DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
|
| 1151 |
-
"will be removed in the future (see https://github.com/
|
| 1152 |
|
| 1153 |
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
| 1154 |
LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
|
|
@@ -1156,7 +1237,7 @@ extern "C" {
|
|
| 1156 |
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
| 1157 |
LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep);
|
| 1158 |
|
| 1159 |
-
/// @details Minimum P sampling as described in https://github.com/
|
| 1160 |
LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep);
|
| 1161 |
|
| 1162 |
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
|
@@ -1171,6 +1252,9 @@ extern "C" {
|
|
| 1171 |
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
|
| 1172 |
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
|
| 1173 |
|
|
|
|
| 1174 |
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
| 1175 |
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
| 1176 |
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
|
@@ -1194,22 +1278,38 @@ extern "C" {
|
|
| 1194 |
float tau,
|
| 1195 |
float eta);
|
| 1196 |
|
|
|
|
| 1197 |
LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
|
| 1198 |
const struct llama_vocab * vocab,
|
| 1199 |
const char * grammar_str,
|
| 1200 |
const char * grammar_root);
|
| 1201 |
|
| 1202 |
-
|
| 1203 |
-
/// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in a near future.
|
| 1204 |
-
/// @param trigger_tokens A list of tokens that will trigger the grammar sampler.
|
| 1205 |
-
LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy(
|
| 1206 |
const struct llama_vocab * vocab,
|
| 1207 |
const char * grammar_str,
|
| 1208 |
const char * grammar_root,
|
| 1209 |
const char ** trigger_words,
|
| 1210 |
size_t num_trigger_words,
|
| 1211 |
const llama_token * trigger_tokens,
|
| 1212 |
-
size_t num_trigger_tokens)
|
|
|
|
|
| 1213 |
|
| 1214 |
/// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
|
| 1215 |
LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
|
|
|
|
| 60 |
struct llama_model;
|
| 61 |
struct llama_context;
|
| 62 |
struct llama_sampler;
|
| 63 |
+
struct llama_kv_cache;
|
| 64 |
|
| 65 |
typedef int32_t llama_pos;
|
| 66 |
typedef int32_t llama_token;
|
|
|
|
| 106 |
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
|
| 107 |
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
|
| 108 |
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
|
| 109 |
+
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
|
| 110 |
+
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
|
| 111 |
+
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
|
| 112 |
+
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
|
| 113 |
+
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
|
| 114 |
+
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
|
| 115 |
};
|
| 116 |
|
| 117 |
enum llama_rope_type {
|
|
|
|
| 220 |
LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported
|
| 221 |
};
|
| 222 |
|
| 223 |
+
// TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979)
|
| 224 |
typedef struct llama_token_data {
|
| 225 |
llama_token id; // token id
|
| 226 |
float logit; // log-odds of the token
|
|
|
|
| 282 |
};
|
| 283 |
};
|
| 284 |
|
| 285 |
+
struct llama_model_tensor_buft_override {
|
| 286 |
+
const char * pattern;
|
| 287 |
+
ggml_backend_buffer_type_t buft;
|
| 288 |
+
};
|
| 289 |
+
|
| 290 |
struct llama_model_params {
|
| 291 |
// NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
|
| 292 |
ggml_backend_dev_t * devices;
|
| 293 |
|
| 294 |
+
// NULL-terminated list of buffer types to use for tensors that match a pattern
|
| 295 |
+
const struct llama_model_tensor_buft_override * tensor_buft_overrides;
|
| 296 |
+
|
| 297 |
int32_t n_gpu_layers; // number of layers to store in VRAM
|
| 298 |
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
|
| 299 |
|
|
|
|
| 322 |
};
|
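llama_model_params now carries a NULL-terminated tensor_buft_overrides list that routes tensors whose names match a pattern to a specific backend buffer type. A hedged sketch follows; the pattern string and model path are illustrative, and ggml_backend_cpu_buffer_type() is the stock CPU buffer type from ggml-backend.

    // Sketch: keep tensors whose names match a pattern in host (CPU) buffers.
    const llama_model_tensor_buft_override overrides[] = {
        { "ffn_.*_exps", ggml_backend_cpu_buffer_type() }, // illustrative pattern
        { nullptr,       nullptr },                        // terminator
    };

    llama_model_params mparams = llama_model_default_params();
    mparams.tensor_buft_overrides = overrides;

    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path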
| 323 |
|
| 324 |
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
|
| 325 |
+
// https://github.com/ggml-org/llama.cpp/pull/7544
|
| 326 |
struct llama_context_params {
|
| 327 |
uint32_t n_ctx; // text context, 0 = from model
|
| 328 |
uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
|
|
|
|
| 335 |
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
|
| 336 |
enum llama_attention_type attention_type; // attention type to use for embeddings
|
| 337 |
|
| 338 |
+
// ref: https://github.com/ggml-org/llama.cpp/pull/2054
|
| 339 |
float rope_freq_base; // RoPE base frequency, 0 = from model
|
| 340 |
float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
|
| 341 |
float yarn_ext_factor; // YaRN extrapolation mix factor, negative = from model
|
|
|
|
| 368 |
|
| 369 |
// model quantization parameters
|
| 370 |
typedef struct llama_model_quantize_params {
|
| 371 |
+
int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
|
| 372 |
+
enum llama_ftype ftype; // quantize to this llama_ftype
|
| 373 |
+
enum ggml_type output_tensor_type; // output tensor type
|
| 374 |
+
enum ggml_type token_embedding_type; // token embeddings tensor type
|
| 375 |
+
bool allow_requantize; // allow quantizing non-f32/f16 tensors
|
| 376 |
+
bool quantize_output_tensor; // quantize output.weight
|
| 377 |
+
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
|
| 378 |
+
bool pure; // quantize all tensors to the default type
|
| 379 |
+
bool keep_split; // quantize to the same number of shards
|
| 380 |
+
void * imatrix; // pointer to importance matrix data
|
| 381 |
+
void * kv_overrides; // pointer to vector containing overrides
|
| 382 |
+
void * tensor_types; // pointer to vector containing tensor types
|
| 383 |
} llama_model_quantize_params;
|
| 384 |
|
| 385 |
typedef struct llama_logit_bias {
|
|
|
|
| 401 |
struct llama_adapter_lora;
|
| 402 |
|
| 403 |
// Helpers for getting default parameters
|
| 404 |
+
// TODO: update API to start accepting pointers to params structs (https://github.com/ggml-org/llama.cpp/discussions/9172)
|
| 405 |
LLAMA_API struct llama_model_params llama_model_default_params(void);
|
| 406 |
LLAMA_API struct llama_context_params llama_context_default_params(void);
|
| 407 |
LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void);
|
|
|
|
| 484 |
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
|
| 485 |
|
| 486 |
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
|
| 487 |
+
LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx);
|
| 488 |
+
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
|
| 489 |
|
| 490 |
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
|
| 491 |
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
|
|
|
|
| 494 |
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
|
| 495 |
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
|
| 496 |
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
|
| 497 |
+
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
|
| 498 |
|
| 499 |
// Get the model's RoPE frequency scaling factor
|
| 500 |
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
|
|
|
|
| 602 |
// KV cache
|
| 603 |
//
|
| 604 |
|
| 605 |
+
// TODO: start using struct llama_kv_cache
|
| 606 |
|
| 607 |
// Information associated with an individual cell in the KV cache view.
|
| 608 |
struct llama_kv_cache_view_cell {
|
|
|
|
| 657 |
|
| 658 |
// Returns the number of tokens in the KV cache (slow, use only for debug)
|
| 659 |
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
|
| 660 |
+
LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
|
| 661 |
+
|
| 662 |
+
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
|
| 663 |
+
"use llama_kv_self_n_tokens instead");
|
| 664 |
|
| 665 |
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
| 666 |
+
LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
|
| 667 |
+
|
| 668 |
+
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
|
| 669 |
+
"use llama_kv_self_used_cells instead");
|
| 670 |
|
| 671 |
// Clear the KV cache - both cell info is erased and KV data is zeroed
|
| 672 |
+
LLAMA_API void llama_kv_self_clear(
|
| 673 |
struct llama_context * ctx);
|
| 674 |
|
| 675 |
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
|
|
|
| 677 |
// seq_id < 0 : match any sequence
|
| 678 |
// p0 < 0 : [0, p1]
|
| 679 |
// p1 < 0 : [p0, inf)
|
| 680 |
+
LLAMA_API bool llama_kv_self_seq_rm(
|
| 681 |
struct llama_context * ctx,
|
| 682 |
llama_seq_id seq_id,
|
| 683 |
llama_pos p0,
|
|
|
|
| 687 |
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
|
| 688 |
// p0 < 0 : [0, p1]
|
| 689 |
// p1 < 0 : [p0, inf)
|
| 690 |
+
LLAMA_API void llama_kv_self_seq_cp(
|
| 691 |
struct llama_context * ctx,
|
| 692 |
llama_seq_id seq_id_src,
|
| 693 |
llama_seq_id seq_id_dst,
|
|
|
|
| 695 |
llama_pos p1);
|
| 696 |
|
| 697 |
// Removes all tokens that do not belong to the specified sequence
|
| 698 |
+
LLAMA_API void llama_kv_self_seq_keep(
|
| 699 |
struct llama_context * ctx,
|
| 700 |
llama_seq_id seq_id);
|
| 701 |
|
| 702 |
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
| 703 |
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
| 704 |
// - lazily on next llama_decode()
|
| 705 |
+
// - explicitly with llama_kv_self_update()
|
| 706 |
// p0 < 0 : [0, p1]
|
| 707 |
// p1 < 0 : [p0, inf)
|
| 708 |
+
LLAMA_API void llama_kv_self_seq_add(
|
| 709 |
struct llama_context * ctx,
|
| 710 |
llama_seq_id seq_id,
|
| 711 |
llama_pos p0,
|
|
|
|
| 715 |
// Integer division of the positions by factor of `d > 1`
|
| 716 |
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
| 717 |
// - lazily on next llama_decode()
|
| 718 |
+
// - explicitly with llama_kv_self_update()
|
| 719 |
// p0 < 0 : [0, p1]
|
| 720 |
// p1 < 0 : [p0, inf)
|
| 721 |
+
LLAMA_API void llama_kv_self_seq_div(
|
| 722 |
struct llama_context * ctx,
|
| 723 |
llama_seq_id seq_id,
|
| 724 |
llama_pos p0,
|
|
|
|
| 726 |
int d);
|
| 727 |
|
| 728 |
// Returns the largest position present in the KV cache for the specified sequence
|
| 729 |
+
LLAMA_API llama_pos llama_kv_self_seq_pos_max(
|
| 730 |
struct llama_context * ctx,
|
| 731 |
+
llama_seq_id seq_id);
|
|
|
|
|
|
|
|
|
|
| 732 |
|
| 733 |
// Defragment the KV cache
|
| 734 |
// This will be applied:
|
| 735 |
// - lazily on next llama_decode()
|
| 736 |
+
// - explicitly with llama_kv_self_update()
|
| 737 |
+
LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
|
| 738 |
+
|
| 739 |
+
// Check if the context supports KV cache shifting
|
| 740 |
+
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
|
| 741 |
|
| 742 |
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
|
| 743 |
+
LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
|
| 744 |
+
|
| 745 |
+
DEPRECATED(LLAMA_API void llama_kv_cache_clear(
|
| 746 |
+
struct llama_context * ctx),
|
| 747 |
+
"use llama_kv_self_clear instead");
|
| 748 |
+
|
| 749 |
+
DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(
|
| 750 |
+
struct llama_context * ctx,
|
| 751 |
+
llama_seq_id seq_id,
|
| 752 |
+
llama_pos p0,
|
| 753 |
+
llama_pos p1),
|
| 754 |
+
"use llama_kv_self_seq_rm instead");
|
| 755 |
+
|
| 756 |
+
DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(
|
| 757 |
+
struct llama_context * ctx,
|
| 758 |
+
llama_seq_id seq_id_src,
|
| 759 |
+
llama_seq_id seq_id_dst,
|
| 760 |
+
llama_pos p0,
|
| 761 |
+
llama_pos p1),
|
| 762 |
+
"use llama_kv_self_seq_cp instead");
|
| 763 |
+
|
| 764 |
+
DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(
|
| 765 |
+
struct llama_context * ctx,
|
| 766 |
+
llama_seq_id seq_id),
|
| 767 |
+
"use llama_kv_self_seq_keep instead");
|
| 768 |
+
|
| 769 |
+
DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(
|
| 770 |
+
struct llama_context * ctx,
|
| 771 |
+
llama_seq_id seq_id,
|
| 772 |
+
llama_pos p0,
|
| 773 |
+
llama_pos p1,
|
| 774 |
+
llama_pos delta),
|
| 775 |
+
"use llama_kv_self_seq_add instead");
|
| 776 |
+
|
| 777 |
+
DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(
|
| 778 |
+
struct llama_context * ctx,
|
| 779 |
+
llama_seq_id seq_id,
|
| 780 |
+
llama_pos p0,
|
| 781 |
+
llama_pos p1,
|
| 782 |
+
int d),
|
| 783 |
+
"use llama_kv_self_seq_div instead");
|
| 784 |
+
|
| 785 |
+
DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
|
| 786 |
+
struct llama_context * ctx,
|
| 787 |
+
llama_seq_id seq_id),
|
| 788 |
+
"use llama_kv_self_seq_pos_max instead");
|
| 789 |
+
|
| 790 |
+
DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx),
|
| 791 |
+
"use llama_kv_self_defrag instead");
|
| 792 |
+
|
| 793 |
+
DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx),
|
| 794 |
+
"use llama_kv_self_can_shift instead");
|
| 795 |
+
|
| 796 |
+
DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx),
|
| 797 |
+
"use llama_kv_self_update instead");
|
| 798 |
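The block above keeps the old llama_kv_cache_* names as deprecated wrappers around the new llama_kv_self_* functions. A short before/after sketch of the rename, assuming `ctx` is an existing llama_context *:

    // Deprecated spellings (still compile, with warnings):
    llama_kv_cache_clear(ctx);
    llama_kv_cache_seq_rm(ctx, /* seq_id = */ 0, /* p0 = */ 32, /* p1 = */ -1);

    // New spellings from this sync:
    llama_kv_self_clear(ctx);
    llama_kv_self_seq_rm(ctx, /* seq_id = */ 0, /* p0 = */ 32, /* p1 = */ -1);
    llama_kv_self_update(ctx); // apply pending K-shifts / defragmentation explicitly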
|
|
|
|
|
|
|
| 799 |
|
| 800 |
//
|
| 801 |
// State / sessions
|
|
|
|
| 959 |
// If set to true, the model will only attend to the past tokens
|
| 960 |
LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
|
| 961 |
|
| 962 |
+
// Set whether the model is in warmup mode or not
|
| 963 |
+
// If true, all model tensors are activated during llama_decode() to load and cache their weights.
|
| 964 |
+
LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup);
|
| 965 |
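llama_set_warmup(ctx, true) makes the next llama_decode() touch all model tensors so their weights get loaded and cached. A hedged sketch of a warmup pass; decoding a single BOS token mirrors what llama.cpp's examples typically do, but the exact recipe here is an assumption.

    // Sketch: one throw-away decode in warmup mode, then return to normal operation.
    llama_set_warmup(ctx, true);

    llama_token bos   = llama_vocab_bos(vocab);
    llama_batch batch = llama_batch_get_one(&bos, 1);
    llama_decode(ctx, batch);

    llama_kv_self_clear(ctx);   // drop the warmup token from the KV cache
    llama_set_warmup(ctx, false);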
+
|
| 966 |
// Set abort callback
|
| 967 |
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
|
| 968 |
|
|
|
|
| 1120 |
|
| 1121 |
/// Apply chat template. Inspired by hf apply_chat_template() on python.
|
| 1122 |
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
|
| 1123 |
+
 /// NOTE: This function does not use a jinja parser. It only supports a pre-defined list of templates. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
|
| 1124 |
/// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead.
|
| 1125 |
/// @param chat Pointer to a list of multiple llama_chat_message
|
| 1126 |
/// @param n_msg Number of llama_chat_message in this chat
|
|
|
|
| 1194 |
};
|
| 1195 |
|
| 1196 |
struct llama_sampler {
|
| 1197 |
+
const struct llama_sampler_i * iface;
|
| 1198 |
+
llama_sampler_context_t ctx;
|
| 1199 |
};
|
| 1200 |
|
| 1201 |
// mirror of llama_sampler_i:
|
| 1202 |
+
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
|
| 1203 |
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
|
| 1204 |
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
|
| 1205 |
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
|
|
|
|
| 1229 |
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
| 1230 |
/// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
|
| 1231 |
DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
|
| 1232 |
+
"will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)");
|
| 1233 |
|
| 1234 |
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
| 1235 |
LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
|
|
|
|
| 1237 |
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
| 1238 |
LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep);
|
| 1239 |
|
| 1240 |
+
/// @details Minimum P sampling as described in https://github.com/ggml-org/llama.cpp/pull/3841
|
| 1241 |
LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep);
|
| 1242 |
|
| 1243 |
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
|
|
|
| 1252 |
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
|
| 1253 |
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
|
| 1254 |
|
| 1255 |
+
/// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641
|
| 1256 |
+
LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n);
|
| 1257 |
+
|
| 1258 |
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
| 1259 |
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
| 1260 |
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
|
|
|
| 1278 |
float tau,
|
| 1279 |
float eta);
|
| 1280 |
|
| 1281 |
+
 /// @details Initializes a GBNF grammar, see grammars/README.md for details.
|
| 1282 |
+
/// @param vocab The vocabulary that this grammar will be used with.
|
| 1283 |
+
/// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails.
|
| 1284 |
+
/// @param grammar_root The name of the start symbol for the grammar.
|
| 1285 |
LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
|
| 1286 |
const struct llama_vocab * vocab,
|
| 1287 |
const char * grammar_str,
|
| 1288 |
const char * grammar_root);
|
| 1289 |
|
| 1290 |
+
DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy(
|
|
|
|
|
| 1291 |
const struct llama_vocab * vocab,
|
| 1292 |
const char * grammar_str,
|
| 1293 |
const char * grammar_root,
|
| 1294 |
const char ** trigger_words,
|
| 1295 |
size_t num_trigger_words,
|
| 1296 |
const llama_token * trigger_tokens,
|
| 1297 |
+
size_t num_trigger_tokens),
|
| 1298 |
+
"use llama_sampler_init_grammar_lazy_patterns instead");
|
| 1299 |
+
|
| 1300 |
+
|
| 1301 |
+
/// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639
|
| 1302 |
+
/// @param trigger_patterns A list of patterns that will trigger the grammar sampler. Pattern will be matched from the start of the generation output, and grammar sampler will be fed content starting from its first match group.
|
| 1303 |
+
/// @param trigger_tokens A list of tokens that will trigger the grammar sampler. Grammar sampler will be fed content starting from the trigger token included.
|
| 1304 |
+
LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy_patterns(
|
| 1305 |
+
const struct llama_vocab * vocab,
|
| 1306 |
+
const char * grammar_str,
|
| 1307 |
+
const char * grammar_root,
|
| 1308 |
+
const char ** trigger_patterns,
|
| 1309 |
+
size_t num_trigger_patterns,
|
| 1310 |
+
const llama_token * trigger_tokens,
|
| 1311 |
+
size_t num_trigger_tokens);
|
| 1312 |
+
|
| 1313 |
|
| 1314 |
/// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
|
| 1315 |
LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
|
examples/talk-llama/unicode.cpp
CHANGED
|
@@ -618,7 +618,14 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
|
|
| 618 |
result.reserve(utf8.size());
|
| 619 |
size_t offset = 0;
|
| 620 |
while (offset < utf8.size()) {
|
| 621 |
-
|
|
|
|
|
| 622 |
}
|
| 623 |
return result;
|
| 624 |
}
|
|
@@ -701,7 +708,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
|
|
| 701 |
const auto cpts = unicode_cpts_from_utf8(text);
|
| 702 |
|
| 703 |
// generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
|
| 704 |
-
// ref: https://github.com/
|
| 705 |
std::string text_collapsed;
|
| 706 |
if (need_collapse) {
|
| 707 |
// collapse all unicode categories
|
|
|
|
| 618 |
result.reserve(utf8.size());
|
| 619 |
size_t offset = 0;
|
| 620 |
while (offset < utf8.size()) {
|
| 621 |
+
try {
|
| 622 |
+
result.push_back(unicode_cpt_from_utf8(utf8, offset));
|
| 623 |
+
}
|
| 624 |
+
catch (const std::invalid_argument & /*ex*/) {
|
| 625 |
+
// Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize
|
| 626 |
+
++offset;
|
| 627 |
+
result.emplace_back(0xFFFD); // replacement character
|
| 628 |
+
}
|
| 629 |
}
|
| 630 |
return result;
|
| 631 |
}
|
|
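With the change above, invalid UTF-8 no longer throws out of unicode_cpts_from_utf8(); each offending byte is skipped and U+FFFD is emitted in its place. A small illustration of the intended behavior (unicode_cpts_from_utf8 is an internal helper declared in unicode.h, not public llama.h API):

    #include "unicode.h"

    #include <cstdio>
    #include <string>

    int main() {
        // "hi" followed by a lone continuation byte, i.e. invalid UTF-8
        const std::string text = std::string("hi") + '\x80';

        const auto cpts = unicode_cpts_from_utf8(text);

        // expected output: 0x68 0x69 0xFFFD (the bad byte becomes the replacement character)
        for (uint32_t cpt : cpts) {
            printf("0x%X ", cpt);
        }
        printf("\n");
        return 0;
    }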
|
|
| 708 |
const auto cpts = unicode_cpts_from_utf8(text);
|
| 709 |
|
| 710 |
// generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
|
| 711 |
+
// ref: https://github.com/ggml-org/llama.cpp/pull/6920#issuecomment-2081479935
|
| 712 |
std::string text_collapsed;
|
| 713 |
if (need_collapse) {
|
| 714 |
// collapse all unicode categories
|