Spaces:
Running
Running
talk-llama : sync llama.cpp
Browse files
examples/talk-llama/llama.cpp
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/talk-llama/llama.h
CHANGED
|
@@ -37,9 +37,13 @@
|
|
| 37 |
|
| 38 |
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
|
| 39 |
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
|
|
|
| 40 |
|
| 41 |
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
| 42 |
-
#define LLAMA_SESSION_VERSION
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
#ifdef __cplusplus
|
| 45 |
extern "C" {
|
|
@@ -65,6 +69,23 @@ extern "C" {
|
|
| 65 |
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
|
| 66 |
};
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
// note: these values should be synchronized with ggml_rope
|
| 69 |
// TODO: maybe move this enum to ggml.h (ggml_rope_type)
|
| 70 |
enum llama_rope_type {
|
|
@@ -118,6 +139,7 @@ extern "C" {
|
|
| 118 |
LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors
|
| 119 |
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
|
| 120 |
LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
|
|
|
|
| 121 |
|
| 122 |
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
| 123 |
};
|
|
@@ -155,7 +177,7 @@ extern "C" {
|
|
| 155 |
bool sorted;
|
| 156 |
} llama_token_data_array;
|
| 157 |
|
| 158 |
-
typedef bool (*llama_progress_callback)(float progress, void *
|
| 159 |
|
| 160 |
// Input data for llama_decode
|
| 161 |
// A llama_batch object can contain input about one or many sequences
|
|
@@ -191,15 +213,19 @@ extern "C" {
|
|
| 191 |
LLAMA_KV_OVERRIDE_TYPE_INT,
|
| 192 |
LLAMA_KV_OVERRIDE_TYPE_FLOAT,
|
| 193 |
LLAMA_KV_OVERRIDE_TYPE_BOOL,
|
|
|
|
| 194 |
};
|
| 195 |
|
| 196 |
struct llama_model_kv_override {
|
| 197 |
-
char key[128];
|
| 198 |
enum llama_model_kv_override_type tag;
|
|
|
|
|
|
|
|
|
|
| 199 |
union {
|
| 200 |
-
int64_t
|
| 201 |
-
double
|
| 202 |
-
bool
|
|
|
|
| 203 |
};
|
| 204 |
};
|
| 205 |
|
|
@@ -228,9 +254,10 @@ extern "C" {
|
|
| 228 |
const struct llama_model_kv_override * kv_overrides;
|
| 229 |
|
| 230 |
// Keep the booleans together to avoid misalignment during copy-by-value.
|
| 231 |
-
bool vocab_only;
|
| 232 |
-
bool use_mmap;
|
| 233 |
-
bool use_mlock;
|
|
|
|
| 234 |
};
|
| 235 |
|
| 236 |
struct llama_context_params {
|
|
@@ -266,6 +293,7 @@ extern "C" {
|
|
| 266 |
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
| 267 |
bool embeddings; // if true, extract embeddings (together with logits)
|
| 268 |
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
|
|
|
| 269 |
|
| 270 |
// Abort callback
|
| 271 |
// if it returns true, execution of llama_decode() will be aborted
|
|
@@ -284,6 +312,7 @@ extern "C" {
|
|
| 284 |
bool quantize_output_tensor; // quantize output.weight
|
| 285 |
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
|
| 286 |
bool pure; // quantize all tensors to the default type
|
|
|
|
| 287 |
void * imatrix; // pointer to importance matrix data
|
| 288 |
void * kv_overrides; // pointer to vector containing overrides
|
| 289 |
} llama_model_quantize_params;
|
|
@@ -386,8 +415,10 @@ extern "C" {
|
|
| 386 |
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
|
| 387 |
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
|
| 388 |
|
| 389 |
-
LLAMA_API enum
|
| 390 |
-
|
|
|
|
|
|
|
| 391 |
|
| 392 |
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
|
| 393 |
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
|
|
@@ -518,11 +549,12 @@ extern "C" {
|
|
| 518 |
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
| 519 |
LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
|
| 520 |
|
| 521 |
-
// Clear the KV cache
|
| 522 |
LLAMA_API void llama_kv_cache_clear(
|
| 523 |
struct llama_context * ctx);
|
| 524 |
|
| 525 |
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
|
|
|
| 526 |
// seq_id < 0 : match any sequence
|
| 527 |
// p0 < 0 : [0, p1]
|
| 528 |
// p1 < 0 : [p0, inf)
|
|
@@ -594,35 +626,93 @@ extern "C" {
|
|
| 594 |
|
| 595 |
// Returns the maximum size in bytes of the state (rng, logits, embedding
|
| 596 |
// and kv_cache) - will often be smaller after compacting tokens
|
| 597 |
-
LLAMA_API size_t
|
|
|
|
|
|
|
| 598 |
|
| 599 |
// Copies the state to the specified destination address.
|
| 600 |
// Destination needs to have allocated enough memory.
|
| 601 |
// Returns the number of bytes copied
|
| 602 |
-
LLAMA_API size_t
|
| 603 |
struct llama_context * ctx,
|
| 604 |
uint8_t * dst);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
|
| 606 |
// Set the state reading from the specified address
|
| 607 |
// Returns the number of bytes read
|
| 608 |
-
LLAMA_API size_t
|
| 609 |
struct llama_context * ctx,
|
| 610 |
const uint8_t * src);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 611 |
|
| 612 |
// Save/load session file
|
| 613 |
-
LLAMA_API bool
|
| 614 |
struct llama_context * ctx,
|
| 615 |
const char * path_session,
|
| 616 |
llama_token * tokens_out,
|
| 617 |
size_t n_token_capacity,
|
| 618 |
size_t * n_token_count_out);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 619 |
|
| 620 |
-
LLAMA_API bool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
struct llama_context * ctx,
|
| 622 |
const char * path_session,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 623 |
const llama_token * tokens,
|
| 624 |
size_t n_token_count);
|
| 625 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 626 |
//
|
| 627 |
// Decoding
|
| 628 |
//
|
|
@@ -684,8 +774,9 @@ extern "C" {
|
|
| 684 |
// Cols: n_vocab
|
| 685 |
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
|
| 686 |
|
| 687 |
-
// Logits for the ith token. Equivalent to:
|
| 688 |
// llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
|
|
|
|
| 689 |
// returns NULL for invalid ids.
|
| 690 |
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
|
| 691 |
|
|
@@ -697,8 +788,9 @@ extern "C" {
|
|
| 697 |
// Otherwise, returns NULL.
|
| 698 |
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
| 699 |
|
| 700 |
-
// Get the embeddings for the ith token. Equivalent to:
|
| 701 |
// llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
|
|
|
|
| 702 |
// shape: [n_embd] (1-dimensional)
|
| 703 |
// returns NULL for invalid ids.
|
| 704 |
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
|
@@ -718,9 +810,14 @@ extern "C" {
|
|
| 718 |
|
| 719 |
LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
|
| 720 |
|
|
|
|
|
|
|
|
|
|
| 721 |
// Special tokens
|
| 722 |
LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
|
| 723 |
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
|
|
|
|
|
|
|
| 724 |
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
|
| 725 |
|
| 726 |
// Returns -1 if unknown, 1 for true or 0 for false.
|
|
@@ -729,7 +826,7 @@ extern "C" {
|
|
| 729 |
// Returns -1 if unknown, 1 for true or 0 for false.
|
| 730 |
LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
|
| 731 |
|
| 732 |
-
//
|
| 733 |
LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
|
| 734 |
LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
|
| 735 |
LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
|
|
@@ -743,26 +840,28 @@ extern "C" {
|
|
| 743 |
/// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
|
| 744 |
/// @return Returns the number of tokens on success, no more than n_tokens_max
|
| 745 |
/// @return Returns a negative number on failure - the number of tokens that would have been returned
|
| 746 |
-
/// @param
|
| 747 |
-
///
|
| 748 |
LLAMA_API int32_t llama_tokenize(
|
| 749 |
const struct llama_model * model,
|
| 750 |
const char * text,
|
| 751 |
int32_t text_len,
|
| 752 |
llama_token * tokens,
|
| 753 |
int32_t n_tokens_max,
|
| 754 |
-
bool
|
| 755 |
-
bool
|
| 756 |
|
| 757 |
// Token Id -> Piece.
|
| 758 |
// Uses the vocabulary in the provided context.
|
| 759 |
// Does not write null terminator to the buffer.
|
| 760 |
// User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
|
|
|
|
| 761 |
LLAMA_API int32_t llama_token_to_piece(
|
| 762 |
const struct llama_model * model,
|
| 763 |
llama_token token,
|
| 764 |
char * buf,
|
| 765 |
-
int32_t length
|
|
|
|
| 766 |
|
| 767 |
/// Apply chat template. Inspired by hf apply_chat_template() on python.
|
| 768 |
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
|
|
@@ -915,7 +1014,7 @@ extern "C" {
|
|
| 915 |
struct llama_context * ctx,
|
| 916 |
llama_token_data_array * candidates);
|
| 917 |
|
| 918 |
-
/// @details Randomly selects a token from the candidates based on their probabilities.
|
| 919 |
LLAMA_API llama_token llama_sample_token(
|
| 920 |
struct llama_context * ctx,
|
| 921 |
llama_token_data_array * candidates);
|
|
@@ -1002,8 +1101,9 @@ extern "C" {
|
|
| 1002 |
// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
|
| 1003 |
#ifdef LLAMA_API_INTERNAL
|
| 1004 |
|
| 1005 |
-
#include <
|
| 1006 |
#include <string>
|
|
|
|
| 1007 |
|
| 1008 |
struct ggml_tensor;
|
| 1009 |
|
|
@@ -1030,15 +1130,20 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal
|
|
| 1030 |
struct llama_context * ctx
|
| 1031 |
);
|
| 1032 |
|
| 1033 |
-
|
| 1034 |
const std::vector<std::vector<llama_grammar_element>> & rules,
|
| 1035 |
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
|
| 1036 |
-
const uint32_t chr
|
|
|
|
| 1037 |
|
| 1038 |
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
| 1039 |
const std::string & src,
|
| 1040 |
llama_partial_utf8 partial_start);
|
| 1041 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1042 |
#endif // LLAMA_API_INTERNAL
|
| 1043 |
|
| 1044 |
#endif // LLAMA_H
|
|
|
|
| 37 |
|
| 38 |
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
|
| 39 |
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
| 40 |
+
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
|
| 41 |
|
| 42 |
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
| 43 |
+
#define LLAMA_SESSION_VERSION 6
|
| 44 |
+
|
| 45 |
+
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
| 46 |
+
#define LLAMA_STATE_SEQ_VERSION 1
|
| 47 |
|
| 48 |
#ifdef __cplusplus
|
| 49 |
extern "C" {
|
|
|
|
| 69 |
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
|
| 70 |
};
|
| 71 |
|
| 72 |
+
// pre-tokenization types
|
| 73 |
+
enum llama_vocab_pre_type {
|
| 74 |
+
LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
|
| 75 |
+
LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
|
| 76 |
+
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
|
| 77 |
+
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
|
| 78 |
+
LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
|
| 79 |
+
LLAMA_VOCAB_PRE_TYPE_MPT = 5,
|
| 80 |
+
LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
|
| 81 |
+
LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
|
| 82 |
+
LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
|
| 83 |
+
LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
|
| 84 |
+
LLAMA_VOCAB_PRE_TYPE_QWEN2 = 10,
|
| 85 |
+
LLAMA_VOCAB_PRE_TYPE_OLMO = 11,
|
| 86 |
+
LLAMA_VOCAB_PRE_TYPE_DBRX = 12,
|
| 87 |
+
};
|
| 88 |
+
|
| 89 |
// note: these values should be synchronized with ggml_rope
|
| 90 |
// TODO: maybe move this enum to ggml.h (ggml_rope_type)
|
| 91 |
enum llama_rope_type {
|
|
|
|
| 139 |
LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors
|
| 140 |
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
|
| 141 |
LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
|
| 142 |
+
LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
|
| 143 |
|
| 144 |
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
| 145 |
};
|
|
|
|
| 177 |
bool sorted;
|
| 178 |
} llama_token_data_array;
|
| 179 |
|
| 180 |
+
typedef bool (*llama_progress_callback)(float progress, void * user_data);
|
| 181 |
|
| 182 |
// Input data for llama_decode
|
| 183 |
// A llama_batch object can contain input about one or many sequences
|
|
|
|
| 213 |
LLAMA_KV_OVERRIDE_TYPE_INT,
|
| 214 |
LLAMA_KV_OVERRIDE_TYPE_FLOAT,
|
| 215 |
LLAMA_KV_OVERRIDE_TYPE_BOOL,
|
| 216 |
+
LLAMA_KV_OVERRIDE_TYPE_STR,
|
| 217 |
};
|
| 218 |
|
| 219 |
struct llama_model_kv_override {
|
|
|
|
| 220 |
enum llama_model_kv_override_type tag;
|
| 221 |
+
|
| 222 |
+
char key[128];
|
| 223 |
+
|
| 224 |
union {
|
| 225 |
+
int64_t val_i64;
|
| 226 |
+
double val_f64;
|
| 227 |
+
bool val_bool;
|
| 228 |
+
char val_str[128];
|
| 229 |
};
|
| 230 |
};
|
| 231 |
|
|
|
|
| 254 |
const struct llama_model_kv_override * kv_overrides;
|
| 255 |
|
| 256 |
// Keep the booleans together to avoid misalignment during copy-by-value.
|
| 257 |
+
bool vocab_only; // only load the vocabulary, no weights
|
| 258 |
+
bool use_mmap; // use mmap if possible
|
| 259 |
+
bool use_mlock; // force system to keep model in RAM
|
| 260 |
+
bool check_tensors; // validate model tensor data
|
| 261 |
};
|
| 262 |
|
| 263 |
struct llama_context_params {
|
|
|
|
| 293 |
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
| 294 |
bool embeddings; // if true, extract embeddings (together with logits)
|
| 295 |
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
| 296 |
+
bool flash_attn; // whether to use flash attention
|
| 297 |
|
| 298 |
// Abort callback
|
| 299 |
// if it returns true, execution of llama_decode() will be aborted
|
|
|
|
| 312 |
bool quantize_output_tensor; // quantize output.weight
|
| 313 |
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
|
| 314 |
bool pure; // quantize all tensors to the default type
|
| 315 |
+
bool keep_split; // quantize to the same number of shards
|
| 316 |
void * imatrix; // pointer to importance matrix data
|
| 317 |
void * kv_overrides; // pointer to vector containing overrides
|
| 318 |
} llama_model_quantize_params;
|
|
|
|
| 415 |
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
|
| 416 |
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
|
| 417 |
|
| 418 |
+
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
|
| 419 |
+
|
| 420 |
+
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
|
| 421 |
+
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
|
| 422 |
|
| 423 |
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
|
| 424 |
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
|
|
|
|
| 549 |
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
| 550 |
LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
|
| 551 |
|
| 552 |
+
// Clear the KV cache - both cell info is erased and KV data is zeroed
|
| 553 |
LLAMA_API void llama_kv_cache_clear(
|
| 554 |
struct llama_context * ctx);
|
| 555 |
|
| 556 |
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
| 557 |
+
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
|
| 558 |
// seq_id < 0 : match any sequence
|
| 559 |
// p0 < 0 : [0, p1]
|
| 560 |
// p1 < 0 : [p0, inf)
|
|
|
|
| 626 |
|
| 627 |
// Returns the maximum size in bytes of the state (rng, logits, embedding
|
| 628 |
// and kv_cache) - will often be smaller after compacting tokens
|
| 629 |
+
LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
|
| 630 |
+
LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx),
|
| 631 |
+
"use llama_state_get_size instead");
|
| 632 |
|
| 633 |
// Copies the state to the specified destination address.
|
| 634 |
// Destination needs to have allocated enough memory.
|
| 635 |
// Returns the number of bytes copied
|
| 636 |
+
LLAMA_API size_t llama_state_get_data(
|
| 637 |
struct llama_context * ctx,
|
| 638 |
uint8_t * dst);
|
| 639 |
+
LLAMA_API DEPRECATED(size_t llama_copy_state_data(
|
| 640 |
+
struct llama_context * ctx,
|
| 641 |
+
uint8_t * dst),
|
| 642 |
+
"use llama_state_get_data instead");
|
| 643 |
|
| 644 |
// Set the state reading from the specified address
|
| 645 |
// Returns the number of bytes read
|
| 646 |
+
LLAMA_API size_t llama_state_set_data(
|
| 647 |
struct llama_context * ctx,
|
| 648 |
const uint8_t * src);
|
| 649 |
+
LLAMA_API DEPRECATED(size_t llama_set_state_data(
|
| 650 |
+
struct llama_context * ctx,
|
| 651 |
+
const uint8_t * src),
|
| 652 |
+
"use llama_state_set_data instead");
|
| 653 |
|
| 654 |
// Save/load session file
|
| 655 |
+
LLAMA_API bool llama_state_load_file(
|
| 656 |
struct llama_context * ctx,
|
| 657 |
const char * path_session,
|
| 658 |
llama_token * tokens_out,
|
| 659 |
size_t n_token_capacity,
|
| 660 |
size_t * n_token_count_out);
|
| 661 |
+
LLAMA_API DEPRECATED(bool llama_load_session_file(
|
| 662 |
+
struct llama_context * ctx,
|
| 663 |
+
const char * path_session,
|
| 664 |
+
llama_token * tokens_out,
|
| 665 |
+
size_t n_token_capacity,
|
| 666 |
+
size_t * n_token_count_out),
|
| 667 |
+
"use llama_state_load_file instead");
|
| 668 |
|
| 669 |
+
LLAMA_API bool llama_state_save_file(
|
| 670 |
+
struct llama_context * ctx,
|
| 671 |
+
const char * path_session,
|
| 672 |
+
const llama_token * tokens,
|
| 673 |
+
size_t n_token_count);
|
| 674 |
+
LLAMA_API DEPRECATED(bool llama_save_session_file(
|
| 675 |
struct llama_context * ctx,
|
| 676 |
const char * path_session,
|
| 677 |
+
const llama_token * tokens,
|
| 678 |
+
size_t n_token_count),
|
| 679 |
+
"use llama_state_save_file instead");
|
| 680 |
+
|
| 681 |
+
// Get the exact size needed to copy the KV cache of a single sequence
|
| 682 |
+
LLAMA_API size_t llama_state_seq_get_size(
|
| 683 |
+
struct llama_context * ctx,
|
| 684 |
+
llama_seq_id seq_id);
|
| 685 |
+
|
| 686 |
+
// Copy the KV cache of a single sequence into the specified buffer
|
| 687 |
+
LLAMA_API size_t llama_state_seq_get_data(
|
| 688 |
+
struct llama_context * ctx,
|
| 689 |
+
uint8_t * dst,
|
| 690 |
+
llama_seq_id seq_id);
|
| 691 |
+
|
| 692 |
+
// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
|
| 693 |
+
// Returns:
|
| 694 |
+
// - Positive: Ok
|
| 695 |
+
// - Zero: Failed to load
|
| 696 |
+
LLAMA_API size_t llama_state_seq_set_data(
|
| 697 |
+
struct llama_context * ctx,
|
| 698 |
+
const uint8_t * src,
|
| 699 |
+
llama_seq_id dest_seq_id);
|
| 700 |
+
|
| 701 |
+
LLAMA_API size_t llama_state_seq_save_file(
|
| 702 |
+
struct llama_context * ctx,
|
| 703 |
+
const char * filepath,
|
| 704 |
+
llama_seq_id seq_id,
|
| 705 |
const llama_token * tokens,
|
| 706 |
size_t n_token_count);
|
| 707 |
|
| 708 |
+
LLAMA_API size_t llama_state_seq_load_file(
|
| 709 |
+
struct llama_context * ctx,
|
| 710 |
+
const char * filepath,
|
| 711 |
+
llama_seq_id dest_seq_id,
|
| 712 |
+
llama_token * tokens_out,
|
| 713 |
+
size_t n_token_capacity,
|
| 714 |
+
size_t * n_token_count_out);
|
| 715 |
+
|
| 716 |
//
|
| 717 |
// Decoding
|
| 718 |
//
|
|
|
|
| 774 |
// Cols: n_vocab
|
| 775 |
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
|
| 776 |
|
| 777 |
+
// Logits for the ith token. For positive indices, equivalent to:
|
| 778 |
// llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
|
| 779 |
+
// Negative indices can be used to access logits in reverse order; -1 is the last logit.
|
| 780 |
// returns NULL for invalid ids.
|
| 781 |
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
|
| 782 |
|
|
|
|
| 788 |
// Otherwise, returns NULL.
|
| 789 |
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
| 790 |
|
| 791 |
+
// Get the embeddings for the ith token. For positive indices, equivalent to:
|
| 792 |
// llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
|
| 793 |
+
// Negative indices can be used to access embeddings in reverse order; -1 is the last embedding.
|
| 794 |
// shape: [n_embd] (1-dimensional)
|
| 795 |
// returns NULL for invalid ids.
|
| 796 |
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
|
|
|
| 810 |
|
| 811 |
LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
|
| 812 |
|
| 813 |
+
// Check if the token is supposed to end generation (end-of-generation, e.g. EOS, EOT, etc.)
|
| 814 |
+
LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
|
| 815 |
+
|
| 816 |
// Special tokens
|
| 817 |
LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
|
| 818 |
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
|
| 819 |
+
LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
|
| 820 |
+
LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
|
| 821 |
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
|
| 822 |
|
| 823 |
// Returns -1 if unknown, 1 for true or 0 for false.
|
|
|
|
| 826 |
// Returns -1 if unknown, 1 for true or 0 for false.
|
| 827 |
LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
|
| 828 |
|
| 829 |
+
// Codellama infill tokens
|
| 830 |
LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
|
| 831 |
LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
|
| 832 |
LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
|
|
|
|
| 840 |
/// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
|
| 841 |
/// @return Returns the number of tokens on success, no more than n_tokens_max
|
| 842 |
/// @return Returns a negative number on failure - the number of tokens that would have been returned
|
| 843 |
+
/// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
|
| 844 |
+
/// as plaintext. Does not insert a leading space.
|
| 845 |
LLAMA_API int32_t llama_tokenize(
|
| 846 |
const struct llama_model * model,
|
| 847 |
const char * text,
|
| 848 |
int32_t text_len,
|
| 849 |
llama_token * tokens,
|
| 850 |
int32_t n_tokens_max,
|
| 851 |
+
bool add_special,
|
| 852 |
+
bool parse_special);
|
| 853 |
|
| 854 |
// Token Id -> Piece.
|
| 855 |
// Uses the vocabulary in the provided context.
|
| 856 |
// Does not write null terminator to the buffer.
|
| 857 |
// User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
|
| 858 |
+
// @param special If true, special tokens are rendered in the output.
|
| 859 |
LLAMA_API int32_t llama_token_to_piece(
|
| 860 |
const struct llama_model * model,
|
| 861 |
llama_token token,
|
| 862 |
char * buf,
|
| 863 |
+
int32_t length,
|
| 864 |
+
bool special);
|
| 865 |
|
| 866 |
/// Apply chat template. Inspired by hf apply_chat_template() on python.
|
| 867 |
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
|
|
|
|
| 1014 |
struct llama_context * ctx,
|
| 1015 |
llama_token_data_array * candidates);
|
| 1016 |
|
| 1017 |
+
/// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
|
| 1018 |
LLAMA_API llama_token llama_sample_token(
|
| 1019 |
struct llama_context * ctx,
|
| 1020 |
llama_token_data_array * candidates);
|
|
|
|
| 1101 |
// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
|
| 1102 |
#ifdef LLAMA_API_INTERNAL
|
| 1103 |
|
| 1104 |
+
#include <random>
|
| 1105 |
#include <string>
|
| 1106 |
+
#include <vector>
|
| 1107 |
|
| 1108 |
struct ggml_tensor;
|
| 1109 |
|
|
|
|
| 1130 |
struct llama_context * ctx
|
| 1131 |
);
|
| 1132 |
|
| 1133 |
+
void llama_grammar_accept(
|
| 1134 |
const std::vector<std::vector<llama_grammar_element>> & rules,
|
| 1135 |
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
|
| 1136 |
+
const uint32_t chr,
|
| 1137 |
+
std::vector<std::vector<const llama_grammar_element *>> & new_stacks);
|
| 1138 |
|
| 1139 |
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
| 1140 |
const std::string & src,
|
| 1141 |
llama_partial_utf8 partial_start);
|
| 1142 |
|
| 1143 |
+
// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
|
| 1144 |
+
// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
|
| 1145 |
+
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
|
| 1146 |
+
|
| 1147 |
#endif // LLAMA_API_INTERNAL
|
| 1148 |
|
| 1149 |
#endif // LLAMA_H
|
examples/talk-llama/talk-llama.cpp
CHANGED
|
@@ -35,10 +35,10 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
|
|
| 35 |
|
| 36 |
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
|
| 37 |
std::vector<char> result(8, 0);
|
| 38 |
-
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
|
| 39 |
if (n_tokens < 0) {
|
| 40 |
result.resize(-n_tokens);
|
| 41 |
-
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
|
| 42 |
GGML_ASSERT(check == -n_tokens);
|
| 43 |
} else {
|
| 44 |
result.resize(n_tokens);
|
|
|
|
| 35 |
|
| 36 |
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
|
| 37 |
std::vector<char> result(8, 0);
|
| 38 |
+
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), false);
|
| 39 |
if (n_tokens < 0) {
|
| 40 |
result.resize(-n_tokens);
|
| 41 |
+
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), false);
|
| 42 |
GGML_ASSERT(check == -n_tokens);
|
| 43 |
} else {
|
| 44 |
result.resize(n_tokens);
|
examples/talk-llama/unicode-data.cpp
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/talk-llama/unicode-data.h
CHANGED
|
@@ -5,12 +5,13 @@
|
|
| 5 |
#include <utility>
|
| 6 |
#include <vector>
|
| 7 |
|
| 8 |
-
extern const std::vector<std::pair<uint32_t, uint32_t>>
|
| 9 |
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_letter;
|
|
|
|
| 10 |
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace;
|
| 11 |
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_accent_mark;
|
| 12 |
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuation;
|
| 13 |
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
|
| 14 |
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
|
| 15 |
-
extern const std::multimap<uint32_t, uint32_t>
|
| 16 |
-
extern const std::map<char32_t, char32_t>
|
|
|
|
| 5 |
#include <utility>
|
| 6 |
#include <vector>
|
| 7 |
|
| 8 |
+
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_number;
|
| 9 |
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_letter;
|
| 10 |
+
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_separator;
|
| 11 |
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace;
|
| 12 |
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_accent_mark;
|
| 13 |
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuation;
|
| 14 |
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
|
| 15 |
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
|
| 16 |
+
extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
|
| 17 |
+
extern const std::map<char32_t, char32_t> unicode_map_lowercase;
|
examples/talk-llama/unicode.cpp
CHANGED
|
@@ -5,11 +5,15 @@
|
|
| 5 |
#include <cstddef>
|
| 6 |
#include <cstdint>
|
| 7 |
#include <map>
|
|
|
|
| 8 |
#include <stdexcept>
|
| 9 |
#include <string>
|
| 10 |
#include <unordered_map>
|
|
|
|
| 11 |
#include <utility>
|
| 12 |
#include <vector>
|
|
|
|
|
|
|
| 13 |
|
| 14 |
static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
|
| 15 |
std::string result;
|
|
@@ -53,23 +57,22 @@ static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset)
|
|
| 53 |
offset += 4;
|
| 54 |
return result;
|
| 55 |
}
|
| 56 |
-
throw std::invalid_argument("
|
| 57 |
}
|
| 58 |
|
| 59 |
-
static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
|
| 60 |
-
std::vector<uint16_t> result;
|
| 61 |
-
if (/* 0x0000 <= cp && */ cp <= 0xffff) {
|
| 62 |
-
result.emplace_back(cp);
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
result.emplace_back(
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
}
|
| 73 |
|
| 74 |
//static std::vector<uint16_t> unicode_cpts_to_utf16(const std::vector<uint32_t> & cps) {
|
| 75 |
// std::vector<uint16_t> result;
|
|
@@ -80,56 +83,56 @@ static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
|
|
| 80 |
// return result;
|
| 81 |
//}
|
| 82 |
|
| 83 |
-
static uint32_t
|
| 84 |
-
assert(offset < utf16.size());
|
| 85 |
-
if (((utf16[0] >> 10) << 10) != 0xd800) {
|
| 86 |
-
auto result = utf16[offset + 0];
|
| 87 |
-
offset += 1;
|
| 88 |
-
return result;
|
| 89 |
-
}
|
| 90 |
-
|
| 91 |
-
if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) {
|
| 92 |
-
throw std::invalid_argument("invalid character");
|
| 93 |
-
}
|
| 94 |
-
|
| 95 |
-
auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff));
|
| 96 |
-
offset += 2;
|
| 97 |
-
return result;
|
| 98 |
-
}
|
| 99 |
|
| 100 |
//static std::vector<uint32_t> unicode_cpts_from_utf16(const std::vector<uint16_t> & utf16) {
|
| 101 |
// std::vector<uint32_t> result;
|
| 102 |
// size_t offset = 0;
|
| 103 |
// while (offset < utf16.size()) {
|
| 104 |
-
// result.push_back(
|
| 105 |
// }
|
| 106 |
// return result;
|
| 107 |
//}
|
| 108 |
|
| 109 |
static std::unordered_map<uint32_t, int> unicode_cpt_type_map() {
|
| 110 |
std::unordered_map<uint32_t, int> cpt_types;
|
| 111 |
-
for (auto p :
|
| 112 |
-
for (auto i = p.first; i <= p.second; ++
|
| 113 |
-
cpt_types[i] =
|
| 114 |
}
|
| 115 |
}
|
| 116 |
for (auto p : unicode_ranges_letter) {
|
| 117 |
-
for (auto i = p.first; i <= p.second; ++
|
| 118 |
cpt_types[i] = CODEPOINT_TYPE_LETTER;
|
| 119 |
}
|
| 120 |
}
|
| 121 |
-
for (auto p :
|
| 122 |
-
for (auto i = p.first; i <= p.second; ++
|
| 123 |
-
cpt_types[i] =
|
| 124 |
}
|
| 125 |
}
|
| 126 |
for (auto p : unicode_ranges_accent_mark) {
|
| 127 |
-
for (auto i = p.first; i <= p.second; ++
|
| 128 |
cpt_types[i] = CODEPOINT_TYPE_ACCENT_MARK;
|
| 129 |
}
|
| 130 |
}
|
| 131 |
for (auto p : unicode_ranges_punctuation) {
|
| 132 |
-
for (auto i = p.first; i <= p.second; ++
|
| 133 |
cpt_types[i] = CODEPOINT_TYPE_PUNCTUATION;
|
| 134 |
}
|
| 135 |
}
|
|
@@ -139,7 +142,7 @@ static std::unordered_map<uint32_t, int> unicode_cpt_type_map() {
|
|
| 139 |
}
|
| 140 |
}
|
| 141 |
for (auto p : unicode_ranges_control) {
|
| 142 |
-
for (auto i = p.first; i <= p.second; ++
|
| 143 |
cpt_types[i] = CODEPOINT_TYPE_CONTROL;
|
| 144 |
}
|
| 145 |
}
|
|
@@ -194,34 +197,395 @@ static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
|
|
| 194 |
return map;
|
| 195 |
}
|
| 196 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
//
|
| 198 |
// interface
|
| 199 |
//
|
| 200 |
|
| 201 |
std::string unicode_cpt_to_utf8(uint32_t cp) {
|
| 202 |
std::string result;
|
|
|
|
| 203 |
if (/* 0x00 <= cp && */ cp <= 0x7f) {
|
| 204 |
result.push_back(cp);
|
|
|
|
| 205 |
}
|
| 206 |
-
|
| 207 |
result.push_back(0xc0 | ((cp >> 6) & 0x1f));
|
| 208 |
result.push_back(0x80 | (cp & 0x3f));
|
|
|
|
| 209 |
}
|
| 210 |
-
|
| 211 |
result.push_back(0xe0 | ((cp >> 12) & 0x0f));
|
| 212 |
result.push_back(0x80 | ((cp >> 6) & 0x3f));
|
| 213 |
result.push_back(0x80 | (cp & 0x3f));
|
|
|
|
| 214 |
}
|
| 215 |
-
|
| 216 |
result.push_back(0xf0 | ((cp >> 18) & 0x07));
|
| 217 |
result.push_back(0x80 | ((cp >> 12) & 0x3f));
|
| 218 |
result.push_back(0x80 | ((cp >> 6) & 0x3f));
|
| 219 |
result.push_back(0x80 | (cp & 0x3f));
|
|
|
|
| 220 |
}
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
}
|
| 224 |
-
return result;
|
| 225 |
}
|
| 226 |
|
| 227 |
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
|
|
@@ -261,6 +625,19 @@ int unicode_cpt_type(const std::string & utf8) {
|
|
| 261 |
return unicode_cpt_type(unicode_cpt_from_utf8(utf8, offset));
|
| 262 |
}
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
std::string unicode_byte_to_utf8(uint8_t byte) {
|
| 265 |
static std::unordered_map<uint8_t, std::string> map = unicode_byte_to_utf8_map();
|
| 266 |
return map.at(byte);
|
|
@@ -275,3 +652,167 @@ char32_t unicode_tolower(char32_t cp) {
|
|
| 275 |
auto it = unicode_map_lowercase.find(cp);
|
| 276 |
return it == unicode_map_lowercase.end() ? cp : it->second;
|
| 277 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
#include <cstddef>
|
| 6 |
#include <cstdint>
|
| 7 |
#include <map>
|
| 8 |
+
#include <regex>
|
| 9 |
#include <stdexcept>
|
| 10 |
#include <string>
|
| 11 |
#include <unordered_map>
|
| 12 |
+
#include <unordered_set>
|
| 13 |
#include <utility>
|
| 14 |
#include <vector>
|
| 15 |
+
#include <locale>
|
| 16 |
+
#include <codecvt>
|
| 17 |
|
| 18 |
static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
|
| 19 |
std::string result;
|
|
|
|
| 57 |
offset += 4;
|
| 58 |
return result;
|
| 59 |
}
|
| 60 |
+
throw std::invalid_argument("failed to convert utf8 to codepoint");
|
| 61 |
}
|
| 62 |
|
| 63 |
+
//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
|
| 64 |
+
// std::vector<uint16_t> result;
|
| 65 |
+
// if (/* 0x0000 <= cp && */ cp <= 0xffff) {
|
| 66 |
+
// result.emplace_back(cp);
|
| 67 |
+
// return result;
|
| 68 |
+
// }
|
| 69 |
+
// if (0x10000 <= cp && cp <= 0x10ffff) {
|
| 70 |
+
// result.emplace_back(0xd800 | ((cp - 0x10000) >> 10));
|
| 71 |
+
// result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff));
|
| 72 |
+
// return result;
|
| 73 |
+
// }
|
| 74 |
+
// throw std::invalid_argument("failed to convert codepoint to utf16");
|
| 75 |
+
//}
|
|
|
|
| 76 |
|
| 77 |
//static std::vector<uint16_t> unicode_cpts_to_utf16(const std::vector<uint32_t> & cps) {
|
| 78 |
// std::vector<uint16_t> result;
|
|
|
|
| 83 |
// return result;
|
| 84 |
//}
|
| 85 |
|
| 86 |
+
//static uint32_t unicode_cpt_from_utf16(const std::vector<uint16_t> & utf16, size_t & offset) {
|
| 87 |
+
// assert(offset < utf16.size());
|
| 88 |
+
// if (((utf16[0] >> 10) << 10) != 0xd800) {
|
| 89 |
+
// auto result = utf16[offset + 0];
|
| 90 |
+
// offset += 1;
|
| 91 |
+
// return result;
|
| 92 |
+
// }
|
| 93 |
+
//
|
| 94 |
+
// if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) {
|
| 95 |
+
// throw std::invalid_argument("invalid character");
|
| 96 |
+
// }
|
| 97 |
+
//
|
| 98 |
+
// auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff));
|
| 99 |
+
// offset += 2;
|
| 100 |
+
// return result;
|
| 101 |
+
//}
|
| 102 |
|
| 103 |
//static std::vector<uint32_t> unicode_cpts_from_utf16(const std::vector<uint16_t> & utf16) {
|
| 104 |
// std::vector<uint32_t> result;
|
| 105 |
// size_t offset = 0;
|
| 106 |
// while (offset < utf16.size()) {
|
| 107 |
+
// result.push_back(unicode_cpt_from_utf16(utf16, offset));
|
| 108 |
// }
|
| 109 |
// return result;
|
| 110 |
//}
|
| 111 |
|
| 112 |
static std::unordered_map<uint32_t, int> unicode_cpt_type_map() {
|
| 113 |
std::unordered_map<uint32_t, int> cpt_types;
|
| 114 |
+
for (auto p : unicode_ranges_number) {
|
| 115 |
+
for (auto i = p.first; i <= p.second; ++i) {
|
| 116 |
+
cpt_types[i] = CODEPOINT_TYPE_NUMBER;
|
| 117 |
}
|
| 118 |
}
|
| 119 |
for (auto p : unicode_ranges_letter) {
|
| 120 |
+
for (auto i = p.first; i <= p.second; ++i) {
|
| 121 |
cpt_types[i] = CODEPOINT_TYPE_LETTER;
|
| 122 |
}
|
| 123 |
}
|
| 124 |
+
for (auto p : unicode_ranges_separator) {
|
| 125 |
+
for (auto i = p.first; i <= p.second; ++i) {
|
| 126 |
+
cpt_types[i] = CODEPOINT_TYPE_SEPARATOR;
|
| 127 |
}
|
| 128 |
}
|
| 129 |
for (auto p : unicode_ranges_accent_mark) {
|
| 130 |
+
for (auto i = p.first; i <= p.second; ++i) {
|
| 131 |
cpt_types[i] = CODEPOINT_TYPE_ACCENT_MARK;
|
| 132 |
}
|
| 133 |
}
|
| 134 |
for (auto p : unicode_ranges_punctuation) {
|
| 135 |
+
for (auto i = p.first; i <= p.second; ++i) {
|
| 136 |
cpt_types[i] = CODEPOINT_TYPE_PUNCTUATION;
|
| 137 |
}
|
| 138 |
}
|
|
|
|
| 142 |
}
|
| 143 |
}
|
| 144 |
for (auto p : unicode_ranges_control) {
|
| 145 |
+
for (auto i = p.first; i <= p.second; ++i) {
|
| 146 |
cpt_types[i] = CODEPOINT_TYPE_CONTROL;
|
| 147 |
}
|
| 148 |
}
|
|
|
|
| 197 |
return map;
|
| 198 |
}
|
| 199 |
|
| 200 |
+
// Decode a UTF-8 byte string into a std::wstring via std::codecvt_utf8<wchar_t>.
static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
    using utf8_to_wide = std::wstring_convert<std::codecvt_utf8<wchar_t>>;

    utf8_to_wide converter;
    return converter.from_bytes(s);
}
|
| 204 |
+
|
| 205 |
+
static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) {
|
| 206 |
+
std::vector<std::string> bpe_encoded_words;
|
| 207 |
+
for (const auto & word : bpe_words) {
|
| 208 |
+
std::string text_utf;
|
| 209 |
+
auto utf_word = unicode_cpts_from_utf8(word);
|
| 210 |
+
for (size_t i = 0; i < utf_word.size(); ++i) {
|
| 211 |
+
text_utf += unicode_cpt_to_utf8(utf_word[i]);
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
std::string encoded_token;
|
| 215 |
+
for (char & c : text_utf) {
|
| 216 |
+
encoded_token += unicode_byte_to_utf8(c);
|
| 217 |
+
}
|
| 218 |
+
bpe_encoded_words.emplace_back(encoded_token);
|
| 219 |
+
}
|
| 220 |
+
return bpe_encoded_words;
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
|
| 224 |
+
static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & text, const std::vector<size_t> & offsets) {
|
| 225 |
+
std::vector<size_t> bpe_offsets; // store the offset of each word
|
| 226 |
+
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
| 227 |
+
|
| 228 |
+
const auto cpts = unicode_cpts_from_utf8(text);
|
| 229 |
+
|
| 230 |
+
size_t start = 0;
|
| 231 |
+
for (auto offset : offsets) {
|
| 232 |
+
const size_t offset_ini = start;
|
| 233 |
+
const size_t offset_end = start + offset;
|
| 234 |
+
assert(offset_end <= cpts.size());
|
| 235 |
+
start = offset_end;
|
| 236 |
+
|
| 237 |
+
auto _get_cpt = [&] (const size_t pos) -> char32_t {
|
| 238 |
+
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
|
| 239 |
+
};
|
| 240 |
+
|
| 241 |
+
auto _get_cpt_type = [&] (const size_t pos) -> int {
|
| 242 |
+
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
|
| 243 |
+
};
|
| 244 |
+
|
| 245 |
+
size_t _prev_end = offset_ini;
|
| 246 |
+
auto _add_token = [&] (const size_t end) -> size_t {
|
| 247 |
+
assert(_prev_end <= end && end <= offset_end);
|
| 248 |
+
size_t len = end - _prev_end;
|
| 249 |
+
if (len > 0) {
|
| 250 |
+
bpe_offsets.push_back(len);
|
| 251 |
+
}
|
| 252 |
+
_prev_end = end;
|
| 253 |
+
//if (len > 0) {
|
| 254 |
+
// std::string s = "";
|
| 255 |
+
// for(size_t p = end-len; p < end; p++)
|
| 256 |
+
// s += unicode_cpt_to_utf8(cpts[p]);
|
| 257 |
+
// printf(">>> '%s'\n", s.c_str());
|
| 258 |
+
//}
|
| 259 |
+
return len;
|
| 260 |
+
};
|
| 261 |
+
|
| 262 |
+
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
|
| 263 |
+
const char32_t cpt = _get_cpt(pos);
|
| 264 |
+
const int cpt_type = _get_cpt_type(pos);
|
| 265 |
+
|
| 266 |
+
// regex: 's|'t|'re|'ve|'m|'ll|'d
|
| 267 |
+
if (cpt == '\'' && pos+1 < offset_end) {
|
| 268 |
+
char32_t cpt_next = _get_cpt(pos+1);
|
| 269 |
+
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
|
| 270 |
+
pos += _add_token(pos+2);
|
| 271 |
+
continue;
|
| 272 |
+
}
|
| 273 |
+
if (pos+2 < offset_end) {
|
| 274 |
+
char32_t cpt_next_next = _get_cpt(pos+2);
|
| 275 |
+
if ((cpt_next == 'r' && cpt_next_next == 'e') ||
|
| 276 |
+
(cpt_next == 'v' && cpt_next_next == 'e') ||
|
| 277 |
+
(cpt_next == 'l' && cpt_next_next == 'l')) {
|
| 278 |
+
pos += _add_token(pos+3);
|
| 279 |
+
continue;
|
| 280 |
+
}
|
| 281 |
+
}
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
|
| 285 |
+
int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
|
| 286 |
+
// regex: <space>?\p{L}+
|
| 287 |
+
if (cpt2_type == CODEPOINT_TYPE_LETTER) {
|
| 288 |
+
pos += (cpt == ' ');
|
| 289 |
+
while (cpt2_type == CODEPOINT_TYPE_LETTER) {
|
| 290 |
+
cpt2_type = _get_cpt_type(++pos);
|
| 291 |
+
}
|
| 292 |
+
_add_token(pos);
|
| 293 |
+
continue;
|
| 294 |
+
}
|
| 295 |
+
// regex: <space>?\p{N}+
|
| 296 |
+
if (cpt2_type == CODEPOINT_TYPE_NUMBER) {
|
| 297 |
+
pos += (cpt == ' ');
|
| 298 |
+
while (cpt2_type == CODEPOINT_TYPE_NUMBER) {
|
| 299 |
+
cpt2_type = _get_cpt_type(++pos);
|
| 300 |
+
}
|
| 301 |
+
_add_token(pos);
|
| 302 |
+
continue;
|
| 303 |
+
}
|
| 304 |
+
// regex: <space>?[^\s\p{L}\p{N}]+
|
| 305 |
+
if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
|
| 306 |
+
pos += (cpt == ' ');
|
| 307 |
+
while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
|
| 308 |
+
cpt2_type = _get_cpt_type(++pos);
|
| 309 |
+
cpt2 = _get_cpt(pos);
|
| 310 |
+
}
|
| 311 |
+
_add_token(pos);
|
| 312 |
+
continue;
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
size_t num_whitespaces = 0;
|
| 316 |
+
while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) {
|
| 317 |
+
num_whitespaces++;
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
// regex: \s+(?!\S)
|
| 321 |
+
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
|
| 322 |
+
pos += num_whitespaces - 1;
|
| 323 |
+
_add_token(pos);
|
| 324 |
+
continue;
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
// regex: \s+
|
| 328 |
+
if (num_whitespaces > 0) {
|
| 329 |
+
pos += num_whitespaces;
|
| 330 |
+
_add_token(pos);
|
| 331 |
+
continue;
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
// no matches
|
| 335 |
+
_add_token(++pos);
|
| 336 |
+
}
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
return bpe_offsets;
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
// LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
|
| 343 |
+
static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string & text, const std::vector<size_t> & offsets) {
|
| 344 |
+
std::vector<size_t> bpe_offsets; // store the offset of each word
|
| 345 |
+
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
| 346 |
+
|
| 347 |
+
const auto cpts = unicode_cpts_from_utf8(text);
|
| 348 |
+
|
| 349 |
+
size_t start = 0;
|
| 350 |
+
for (auto offset : offsets) {
|
| 351 |
+
const size_t offset_ini = start;
|
| 352 |
+
const size_t offset_end = start + offset;
|
| 353 |
+
assert(offset_end <= cpts.size());
|
| 354 |
+
start = offset_end;
|
| 355 |
+
|
| 356 |
+
auto _get_cpt = [&] (const size_t pos) -> char32_t {
|
| 357 |
+
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
|
| 358 |
+
};
|
| 359 |
+
|
| 360 |
+
auto _get_cpt_type = [&] (const size_t pos) -> int {
|
| 361 |
+
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
|
| 362 |
+
};
|
| 363 |
+
|
| 364 |
+
size_t _prev_end = offset_ini;
|
| 365 |
+
auto _add_token = [&] (const size_t end) -> size_t {
|
| 366 |
+
assert(_prev_end <= end && end <= offset_end);
|
| 367 |
+
size_t len = end - _prev_end;
|
| 368 |
+
if (len > 0) {
|
| 369 |
+
bpe_offsets.push_back(len);
|
| 370 |
+
}
|
| 371 |
+
_prev_end = end;
|
| 372 |
+
//if (len > 0) {
|
| 373 |
+
// std::string s = "";
|
| 374 |
+
// for(size_t p = end-len; p < end; p++)
|
| 375 |
+
// s += unicode_cpt_to_utf8(cpts[p]);
|
| 376 |
+
// printf(">>> '%s'\n", s.c_str());
|
| 377 |
+
//}
|
| 378 |
+
return len;
|
| 379 |
+
};
|
| 380 |
+
|
| 381 |
+
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
|
| 382 |
+
const char32_t cpt = _get_cpt(pos);
|
| 383 |
+
const int cpt_type = _get_cpt_type(pos);
|
| 384 |
+
|
| 385 |
+
// regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
|
| 386 |
+
if (cpt == '\'' && pos+1 < offset_end) {
|
| 387 |
+
char32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
|
| 388 |
+
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
|
| 389 |
+
pos += _add_token(pos+2);
|
| 390 |
+
continue;
|
| 391 |
+
}
|
| 392 |
+
if (pos+2 < offset_end) {
|
| 393 |
+
char32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
|
| 394 |
+
if ((cpt_next == 'r' && cpt_next_next == 'e') ||
|
| 395 |
+
(cpt_next == 'v' && cpt_next_next == 'e') ||
|
| 396 |
+
(cpt_next == 'l' && cpt_next_next == 'l')) {
|
| 397 |
+
pos += _add_token(pos+3);
|
| 398 |
+
continue;
|
| 399 |
+
}
|
| 400 |
+
}
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
// regex: [^\r\n\p{L}\p{N}]?\p{L}+ //####FIXME: the first \p{L} is correct?
|
| 404 |
+
if (cpt != '\r' && cpt != '\n' && /*cpt_type != CODEPOINT_TYPE_LETTER &&*/ cpt_type != CODEPOINT_TYPE_NUMBER) {
|
| 405 |
+
if (cpt_type == CODEPOINT_TYPE_LETTER || _get_cpt_type(pos+1) == CODEPOINT_TYPE_LETTER) { // one or more letters
|
| 406 |
+
pos++;
|
| 407 |
+
while (_get_cpt_type(pos) == CODEPOINT_TYPE_LETTER) {
|
| 408 |
+
pos++;
|
| 409 |
+
}
|
| 410 |
+
_add_token(pos);
|
| 411 |
+
continue;
|
| 412 |
+
}
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
// regex: \p{N}{1,3}
|
| 416 |
+
if (cpt_type == CODEPOINT_TYPE_NUMBER) {
|
| 417 |
+
size_t ini = pos;
|
| 418 |
+
while (_get_cpt_type(pos) == CODEPOINT_TYPE_NUMBER) {
|
| 419 |
+
if (++pos - ini >= 3 ) {
|
| 420 |
+
_add_token(pos);
|
| 421 |
+
ini = pos;
|
| 422 |
+
}
|
| 423 |
+
}
|
| 424 |
+
_add_token(pos);
|
| 425 |
+
continue;
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
|
| 429 |
+
char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
|
| 430 |
+
int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
|
| 431 |
+
if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
|
| 432 |
+
pos += (cpt == ' ');
|
| 433 |
+
while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
|
| 434 |
+
cpt2_type = _get_cpt_type(++pos);
|
| 435 |
+
cpt2 = _get_cpt(pos);
|
| 436 |
+
}
|
| 437 |
+
while (cpt2 == '\r' || cpt2 == '\n') {
|
| 438 |
+
cpt2 = _get_cpt(++pos);
|
| 439 |
+
}
|
| 440 |
+
_add_token(pos);
|
| 441 |
+
continue;
|
| 442 |
+
}
|
| 443 |
+
|
| 444 |
+
size_t num_whitespaces = 0;
|
| 445 |
+
size_t last_end_r_or_n = 0;
|
| 446 |
+
while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) {
|
| 447 |
+
char32_t cpt2 = _get_cpt(pos+num_whitespaces);
|
| 448 |
+
if (cpt2 == '\r' || cpt2 == '\n') {
|
| 449 |
+
last_end_r_or_n = pos + num_whitespaces + 1;
|
| 450 |
+
}
|
| 451 |
+
num_whitespaces++;
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
// regex: \s*[\r\n]+
|
| 455 |
+
if (last_end_r_or_n > 0) {
|
| 456 |
+
pos = last_end_r_or_n;
|
| 457 |
+
_add_token(pos);
|
| 458 |
+
continue;
|
| 459 |
+
}
|
| 460 |
+
|
| 461 |
+
// regex: \s+(?!\S)
|
| 462 |
+
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
|
| 463 |
+
pos += num_whitespaces - 1;
|
| 464 |
+
_add_token(pos);
|
| 465 |
+
continue;
|
| 466 |
+
}
|
| 467 |
+
|
| 468 |
+
// regex: \s+
|
| 469 |
+
if (num_whitespaces > 0) {
|
| 470 |
+
pos += num_whitespaces;
|
| 471 |
+
_add_token(pos);
|
| 472 |
+
continue;
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
// no matches
|
| 476 |
+
_add_token(++pos);
|
| 477 |
+
}
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
return bpe_offsets;
|
| 481 |
+
}
|
| 482 |
+
|
| 483 |
+
// Shared implementation behind both unicode_regex_split_stl() overloads.
//
// `offsets` holds the lengths of consecutive chunks of `text` (in units of CharT). Each chunk
// is scanned with `regex_expr`; for every chunk the output receives, in order:
//   - the length of any unmatched gap before a match (only when the gap is non-empty),
//   - the length of the match itself,
//   - the length of any unmatched tail after the last match (only when non-empty).
//
// Throws std::regex_error when `regex_expr` is not a valid expression.
template <typename CharT>
static std::vector<size_t> unicode_regex_split_stl_impl(
        const std::basic_string<CharT> & text,
        const std::basic_string<CharT> & regex_expr,
        const std::vector<size_t>      & offsets) {
    using regex_iter = std::regex_iterator<const CharT *>;

    const std::basic_regex<CharT> expr(regex_expr);

    std::vector<size_t> bpe_offsets; // length of each resulting piece
    bpe_offsets.reserve(offsets.size()); // only an approximation of the final size

    size_t start = 0;
    for (const size_t offset : offsets) {
        const CharT * chunk = text.data() + start;

        // position (relative to the chunk) just past the previous match
        int64_t start_idx = 0;

        const regex_iter end;
        for (regex_iter it(chunk, chunk + offset, expr); it != end; ++it) {
            const int64_t match_pos = it->position();
            const int64_t match_len = it->length();

            if (match_pos > start_idx) {
                bpe_offsets.emplace_back(match_pos - start_idx);
            }
            bpe_offsets.emplace_back(match_len);
            start_idx = match_pos + match_len;
        }

        if (start_idx < static_cast<int64_t>(offset)) {
            bpe_offsets.emplace_back(offset - start_idx);
        }
        start += offset;
    }

    return bpe_offsets;
}

// use std::wregex to split the text
static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector<size_t> & offsets) {
    return unicode_regex_split_stl_impl<wchar_t>(wtext, regex_expr, offsets);
}

// use std::regex to split the text
static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
    return unicode_regex_split_stl_impl<char>(text, regex_expr, offsets);
}
|
| 542 |
+
|
| 543 |
+
static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
|
| 544 |
+
std::vector<size_t> bpe_offsets;
|
| 545 |
+
|
| 546 |
+
if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
|
| 547 |
+
bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
|
| 548 |
+
} else if (
|
| 549 |
+
regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
|
| 550 |
+
regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
|
| 551 |
+
|
| 552 |
+
bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
|
| 553 |
+
}
|
| 554 |
+
|
| 555 |
+
return bpe_offsets;
|
| 556 |
+
}
|
| 557 |
+
|
| 558 |
//
|
| 559 |
// interface
|
| 560 |
//
|
| 561 |
|
| 562 |
// Encode a single codepoint as UTF-8.
//
// Produces 1 byte for cp <= 0x7f, 2 bytes up to 0x7ff, 3 bytes up to 0xffff and
// 4 bytes for 0x10000..0x10ffff. Any larger value throws std::invalid_argument.
std::string unicode_cpt_to_utf8(uint32_t cp) {
    // leading byte: marker bits combined with the top payload bits of cp
    const auto lead = [cp] (const uint32_t marker, const unsigned shift, const uint32_t mask) -> char {
        return static_cast<char>(marker | ((cp >> shift) & mask));
    };
    // continuation byte: 10xxxxxx carrying six payload bits of cp
    const auto cont = [cp] (const unsigned shift) -> char {
        return static_cast<char>(0x80 | ((cp >> shift) & 0x3f));
    };

    if (cp <= 0x7f) {
        return std::string(1, static_cast<char>(cp));
    }
    if (cp <= 0x7ff) {
        return std::string{ lead(0xc0, 6, 0x1f), cont(0) };
    }
    if (cp <= 0xffff) {
        return std::string{ lead(0xe0, 12, 0x0f), cont(6), cont(0) };
    }
    if (cp <= 0x10ffff) {
        return std::string{ lead(0xf0, 18, 0x07), cont(12), cont(6), cont(0) };
    }

    throw std::invalid_argument("invalid codepoint");
}
|
| 590 |
|
| 591 |
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
|
|
|
|
| 625 |
return unicode_cpt_type(unicode_cpt_from_utf8(utf8, offset));
|
| 626 |
}
|
| 627 |
|
| 628 |
+
bool unicode_cpt_is_whitespace(uint32_t cp) {
|
| 629 |
+
static const std::unordered_set<uint32_t> is_whitespace = [] {
|
| 630 |
+
std::unordered_set<uint32_t> is_whitespace;
|
| 631 |
+
for (auto p : unicode_ranges_whitespace) {
|
| 632 |
+
for (auto i = p.first; i <= p.second; ++i) {
|
| 633 |
+
is_whitespace.insert(i);
|
| 634 |
+
}
|
| 635 |
+
}
|
| 636 |
+
return is_whitespace;
|
| 637 |
+
}();
|
| 638 |
+
return (bool)is_whitespace.count(cp);
|
| 639 |
+
}
|
| 640 |
+
|
| 641 |
std::string unicode_byte_to_utf8(uint8_t byte) {
|
| 642 |
static std::unordered_map<uint8_t, std::string> map = unicode_byte_to_utf8_map();
|
| 643 |
return map.at(byte);
|
|
|
|
| 652 |
auto it = unicode_map_lowercase.find(cp);
|
| 653 |
return it == unicode_map_lowercase.end() ? cp : it->second;
|
| 654 |
}
|
| 655 |
+
|
| 656 |
+
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
|
| 657 |
+
// unicode categories
|
| 658 |
+
static const std::map<std::string, int> k_ucat_enum = {
|
| 659 |
+
{ "\\p{N}", CODEPOINT_TYPE_NUMBER },
|
| 660 |
+
{ "\\p{L}", CODEPOINT_TYPE_LETTER },
|
| 661 |
+
{ "\\p{P}", CODEPOINT_TYPE_PUNCTUATION },
|
| 662 |
+
};
|
| 663 |
+
|
| 664 |
+
static const std::map<int, int> k_ucat_cpt = {
|
| 665 |
+
{ CODEPOINT_TYPE_NUMBER, 0xD1 },
|
| 666 |
+
{ CODEPOINT_TYPE_LETTER, 0xD2 },
|
| 667 |
+
{ CODEPOINT_TYPE_PUNCTUATION, 0xD3 },
|
| 668 |
+
};
|
| 669 |
+
|
| 670 |
+
static const std::map<int, std::string> k_ucat_map = {
|
| 671 |
+
{ CODEPOINT_TYPE_NUMBER, "\x30-\x39" }, // 0-9
|
| 672 |
+
{ CODEPOINT_TYPE_LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
|
| 673 |
+
{ CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
|
| 674 |
+
};
|
| 675 |
+
|
| 676 |
+
// compute collapsed codepoints only if needed by at least one regex
|
| 677 |
+
bool need_collapse = false;
|
| 678 |
+
for (auto & regex_expr : regex_exprs) {
|
| 679 |
+
// search for unicode categories
|
| 680 |
+
for (const auto & ucat : k_ucat_enum) {
|
| 681 |
+
if (std::string::npos != regex_expr.find(ucat.first)) {
|
| 682 |
+
need_collapse = true;
|
| 683 |
+
break;
|
| 684 |
+
}
|
| 685 |
+
}
|
| 686 |
+
}
|
| 687 |
+
|
| 688 |
+
const auto cpts = unicode_cpts_from_utf8(text);
|
| 689 |
+
|
| 690 |
+
// generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
|
| 691 |
+
// ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
|
| 692 |
+
std::string text_collapsed;
|
| 693 |
+
if (need_collapse) {
|
| 694 |
+
// collapse all unicode categories
|
| 695 |
+
text_collapsed.resize(cpts.size());
|
| 696 |
+
|
| 697 |
+
for (size_t i = 0; i < cpts.size(); ++i) {
|
| 698 |
+
// keep single-byte codepoints as is
|
| 699 |
+
if (cpts[i] < 128) {
|
| 700 |
+
text_collapsed[i] = cpts[i];
|
| 701 |
+
continue;
|
| 702 |
+
}
|
| 703 |
+
|
| 704 |
+
const int cpt_type = unicode_cpt_type(cpts[i]);
|
| 705 |
+
|
| 706 |
+
if (k_ucat_cpt.find(cpt_type) != k_ucat_cpt.end()) {
|
| 707 |
+
text_collapsed[i] = k_ucat_cpt.at(cpt_type);
|
| 708 |
+
} else {
|
| 709 |
+
text_collapsed[i] = (char) 0xD0; // fallback
|
| 710 |
+
}
|
| 711 |
+
}
|
| 712 |
+
}
|
| 713 |
+
|
| 714 |
+
std::vector<size_t> bpe_offsets = { cpts.size() };
|
| 715 |
+
|
| 716 |
+
for (auto & regex_expr : regex_exprs) {
|
| 717 |
+
// first, see if we have an efficient custom regex implementation
|
| 718 |
+
auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
|
| 719 |
+
|
| 720 |
+
if (!tmp.empty()) {
|
| 721 |
+
bpe_offsets = std::move(tmp);
|
| 722 |
+
continue;
|
| 723 |
+
}
|
| 724 |
+
|
| 725 |
+
// fallback to general-purpose std::regex / std::wregex
|
| 726 |
+
try {
|
| 727 |
+
// if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
|
| 728 |
+
// with the corresponding collapsed representation
|
| 729 |
+
bool use_collapsed = false;
|
| 730 |
+
for (auto & ucat : k_ucat_enum) {
|
| 731 |
+
if (std::string::npos != regex_expr.find(ucat.first)) {
|
| 732 |
+
use_collapsed = true;
|
| 733 |
+
break;
|
| 734 |
+
}
|
| 735 |
+
}
|
| 736 |
+
|
| 737 |
+
if (use_collapsed) {
|
| 738 |
+
// sanity-check that the original regex does not contain any non-ASCII characters
|
| 739 |
+
const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
|
| 740 |
+
for (size_t i = 0; i < cpts_regex.size(); ++i) {
|
| 741 |
+
if (cpts_regex[i] >= 128) {
|
| 742 |
+
throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported");
|
| 743 |
+
}
|
| 744 |
+
}
|
| 745 |
+
|
| 746 |
+
// generate a collapsed representation of the regex
|
| 747 |
+
std::string regex_expr_collapsed;
|
| 748 |
+
|
| 749 |
+
// track if we are inside [], because nested [] are not allowed
|
| 750 |
+
bool inside = false;
|
| 751 |
+
for (size_t i = 0; i < regex_expr.size(); ++i) {
|
| 752 |
+
if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) {
|
| 753 |
+
regex_expr_collapsed += '[';
|
| 754 |
+
inside = true;
|
| 755 |
+
continue;
|
| 756 |
+
}
|
| 757 |
+
|
| 758 |
+
if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') {
|
| 759 |
+
regex_expr_collapsed += ']';
|
| 760 |
+
inside = false;
|
| 761 |
+
continue;
|
| 762 |
+
}
|
| 763 |
+
|
| 764 |
+
if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() &&
|
| 765 |
+
regex_expr[i + 1] == 'p' &&
|
| 766 |
+
regex_expr[i + 2] == '{' &&
|
| 767 |
+
regex_expr[i + 4] == '}') {
|
| 768 |
+
const std::string pat = regex_expr.substr(i, 5);
|
| 769 |
+
if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
|
| 770 |
+
if (!inside) {
|
| 771 |
+
regex_expr_collapsed += '[';
|
| 772 |
+
}
|
| 773 |
+
regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
|
| 774 |
+
regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
|
| 775 |
+
if (!inside) {
|
| 776 |
+
regex_expr_collapsed += ']';
|
| 777 |
+
}
|
| 778 |
+
i += 4;
|
| 779 |
+
continue;
|
| 780 |
+
}
|
| 781 |
+
}
|
| 782 |
+
|
| 783 |
+
regex_expr_collapsed += regex_expr[i];
|
| 784 |
+
}
|
| 785 |
+
|
| 786 |
+
//printf("text_collapsed: %s\n", text_collapsed.c_str());
|
| 787 |
+
//printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
|
| 788 |
+
bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
|
| 789 |
+
} else {
|
| 790 |
+
// no unicode category used, we can use std::wregex directly
|
| 791 |
+
const std::wstring wtext = unicode_wstring_from_utf8(text);
|
| 792 |
+
const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
|
| 793 |
+
|
| 794 |
+
//printf("text: %s\n", text.c_str());
|
| 795 |
+
//printf("regex_expr: %s\n", regex_expr.c_str());
|
| 796 |
+
bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
|
| 797 |
+
}
|
| 798 |
+
} catch (std::regex_error & e) {
|
| 799 |
+
fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
|
| 800 |
+
fprintf(stderr, "Regex error: %s\n", e.what());
|
| 801 |
+
throw std::runtime_error("Failed to process regex");
|
| 802 |
+
}
|
| 803 |
+
}
|
| 804 |
+
|
| 805 |
+
std::vector<std::string> bpe_words;
|
| 806 |
+
bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size
|
| 807 |
+
|
| 808 |
+
size_t start = 0;
|
| 809 |
+
for (size_t & offset : bpe_offsets) {
|
| 810 |
+
bpe_words.emplace_back();
|
| 811 |
+
for (size_t i = start; i < start + offset; ++i) {
|
| 812 |
+
bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);
|
| 813 |
+
}
|
| 814 |
+
start += offset;
|
| 815 |
+
}
|
| 816 |
+
|
| 817 |
+
return unicode_byte_encoding_process(bpe_words);
|
| 818 |
+
}
|
examples/talk-llama/unicode.h
CHANGED
|
@@ -5,9 +5,9 @@
|
|
| 5 |
#include <vector>
|
| 6 |
|
| 7 |
#define CODEPOINT_TYPE_UNIDENTIFIED 0
|
| 8 |
-
#define CODEPOINT_TYPE_DIGIT 1
|
| 9 |
#define CODEPOINT_TYPE_LETTER 2
|
| 10 |
-
#define CODEPOINT_TYPE_WHITESPACE 3
|
| 11 |
#define CODEPOINT_TYPE_ACCENT_MARK 4
|
| 12 |
#define CODEPOINT_TYPE_PUNCTUATION 5
|
| 13 |
#define CODEPOINT_TYPE_SYMBOL 6
|
|
@@ -21,8 +21,11 @@ std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & c
|
|
| 21 |
int unicode_cpt_type(uint32_t cp);
|
| 22 |
int unicode_cpt_type(const std::string & utf8);
|
| 23 |
|
|
|
|
|
|
|
| 24 |
std::string unicode_byte_to_utf8(uint8_t byte);
|
| 25 |
uint8_t unicode_utf8_to_byte(const std::string & utf8);
|
| 26 |
|
| 27 |
-
// simple tolower that only implements one-to-one mapping, not one-to-many
|
| 28 |
char32_t unicode_tolower(char32_t cp);
|
|
|
|
|
|
|
|
|
| 5 |
#include <vector>
|
| 6 |
|
| 7 |
#define CODEPOINT_TYPE_UNIDENTIFIED 0
|
| 8 |
+
#define CODEPOINT_TYPE_NUMBER 1
|
| 9 |
#define CODEPOINT_TYPE_LETTER 2
|
| 10 |
+
#define CODEPOINT_TYPE_SEPARATOR 3
|
| 11 |
#define CODEPOINT_TYPE_ACCENT_MARK 4
|
| 12 |
#define CODEPOINT_TYPE_PUNCTUATION 5
|
| 13 |
#define CODEPOINT_TYPE_SYMBOL 6
|
|
|
|
| 21 |
int unicode_cpt_type(uint32_t cp);
|
| 22 |
int unicode_cpt_type(const std::string & utf8);
|
| 23 |
|
| 24 |
+
bool unicode_cpt_is_whitespace(uint32_t cp);
|
| 25 |
+
|
| 26 |
std::string unicode_byte_to_utf8(uint8_t byte);
|
| 27 |
uint8_t unicode_utf8_to_byte(const std::string & utf8);
|
| 28 |
|
|
|
|
| 29 |
char32_t unicode_tolower(char32_t cp);
|
| 30 |
+
|
| 31 |
+
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);
|