ggerganov committed
Commit f5f68d6 · Parent: 3ea4549

talk-llama : sync llama.cpp
examples/talk-llama/llama.cpp CHANGED
The diff for this file is too large to render. See raw diff
 
examples/talk-llama/llama.h CHANGED
@@ -37,9 +37,13 @@
 
 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
+#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 5
+#define LLAMA_SESSION_VERSION 6
+
+#define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
+#define LLAMA_STATE_SEQ_VERSION 1
 
 #ifdef __cplusplus
 extern "C" {
@@ -65,6 +69,23 @@ extern "C" {
         LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
     };
 
+    // pre-tokenization types
+    enum llama_vocab_pre_type {
+        LLAMA_VOCAB_PRE_TYPE_DEFAULT        = 0,
+        LLAMA_VOCAB_PRE_TYPE_LLAMA3         = 1,
+        LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM   = 2,
+        LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
+        LLAMA_VOCAB_PRE_TYPE_FALCON         = 4,
+        LLAMA_VOCAB_PRE_TYPE_MPT            = 5,
+        LLAMA_VOCAB_PRE_TYPE_STARCODER      = 6,
+        LLAMA_VOCAB_PRE_TYPE_GPT2           = 7,
+        LLAMA_VOCAB_PRE_TYPE_REFACT         = 8,
+        LLAMA_VOCAB_PRE_TYPE_COMMAND_R      = 9,
+        LLAMA_VOCAB_PRE_TYPE_QWEN2          = 10,
+        LLAMA_VOCAB_PRE_TYPE_OLMO           = 11,
+        LLAMA_VOCAB_PRE_TYPE_DBRX           = 12,
+    };
+
     // note: these values should be synchronized with ggml_rope
     // TODO: maybe move this enum to ggml.h (ggml_rope_type)
     enum llama_rope_type {
@@ -118,6 +139,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_IQ2_M   = 29, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ4_XS  = 30, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ1_M   = 31, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_BF16    = 32, // except 1d tensors
 
         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };
@@ -155,7 +177,7 @@ extern "C" {
         bool sorted;
     } llama_token_data_array;
 
-    typedef bool (*llama_progress_callback)(float progress, void * ctx);
+    typedef bool (*llama_progress_callback)(float progress, void * user_data);
 
     // Input data for llama_decode
     // A llama_batch object can contain input about one or many sequences
@@ -191,15 +213,19 @@ extern "C" {
         LLAMA_KV_OVERRIDE_TYPE_INT,
         LLAMA_KV_OVERRIDE_TYPE_FLOAT,
         LLAMA_KV_OVERRIDE_TYPE_BOOL,
+        LLAMA_KV_OVERRIDE_TYPE_STR,
     };
 
     struct llama_model_kv_override {
-        char key[128];
         enum llama_model_kv_override_type tag;
+
+        char key[128];
+
         union {
-            int64_t int_value;
-            double float_value;
-            bool bool_value;
+            int64_t val_i64;
+            double  val_f64;
+            bool    val_bool;
+            char    val_str[128];
         };
     };
 
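With the new LLAMA_KV_OVERRIDE_TYPE_STR variant and the renamed union members, metadata overrides can now carry strings. A minimal sketch of building an override array under the new layout - the key/value shown are illustrative, and it assumes the loader's convention that the array is terminated by an entry with an empty key:

#include <cstdio>
#include "llama.h"

// Sketch: override a string-valued metadata key at model load time.
static void set_overrides(struct llama_model_params & mparams) {
    static struct llama_model_kv_override overrides[2] = {};

    overrides[0].tag = LLAMA_KV_OVERRIDE_TYPE_STR;
    std::snprintf(overrides[0].key,     sizeof(overrides[0].key),     "%s", "tokenizer.ggml.pre"); // illustrative key
    std::snprintf(overrides[0].val_str, sizeof(overrides[0].val_str), "%s", "llama3");             // illustrative value

    overrides[1].key[0] = '\0'; // assumed terminator convention

    mparams.kv_overrides = overrides;
}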
@@ -228,9 +254,10 @@ extern "C" {
         const struct llama_model_kv_override * kv_overrides;
 
         // Keep the booleans together to avoid misalignment during copy-by-value.
-        bool vocab_only; // only load the vocabulary, no weights
-        bool use_mmap;   // use mmap if possible
-        bool use_mlock;  // force system to keep model in RAM
+        bool vocab_only;    // only load the vocabulary, no weights
+        bool use_mmap;      // use mmap if possible
+        bool use_mlock;     // force system to keep model in RAM
+        bool check_tensors; // validate model tensor data
     };
 
     struct llama_context_params {
@@ -266,6 +293,7 @@ extern "C" {
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
+        bool flash_attn;  // whether to use flash attention
 
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
@@ -284,6 +312,7 @@ extern "C" {
         bool quantize_output_tensor; // quantize output.weight
         bool only_copy;              // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
         bool pure;                   // quantize all tensors to the default type
+        bool keep_split;             // quantize to the same number of shards
        void * imatrix;              // pointer to importance matrix data
        void * kv_overrides;         // pointer to vector containing overrides
     } llama_model_quantize_params;
@@ -386,8 +415,10 @@ extern "C" {
     LLAMA_API uint32_t llama_n_ubatch    (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max   (const struct llama_context * ctx);
 
-    LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
-    LLAMA_API enum llama_rope_type  llama_rope_type (const struct llama_model * model);
+    LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
+
+    LLAMA_API enum llama_vocab_type   llama_vocab_type  (const struct llama_model * model);
+    LLAMA_API enum llama_rope_type    llama_rope_type   (const struct llama_model * model);
 
     LLAMA_API int32_t llama_n_vocab    (const struct llama_model * model);
     LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
@@ -518,11 +549,12 @@ extern "C" {
     // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
     LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
 
-    // Clear the KV cache
+    // Clear the KV cache - both cell info is erased and KV data is zeroed
     LLAMA_API void llama_kv_cache_clear(
             struct llama_context * ctx);
 
     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+    // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
     // seq_id < 0 : match any sequence
     // p0 < 0     : [0,  p1]
     // p1 < 0     : [p0, inf)
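The clarified llama_kv_cache_seq_rm contract is worth handling at call sites. A small sketch that falls back to dropping the whole sequence when a partial removal is rejected (the wrapper name is illustrative):

#include "llama.h"

// Sketch: remove tokens of seq_id from position p0 onwards; if the cache
// cannot remove a partial sequence, remove the whole sequence instead,
// which per the header never fails.
static void trim_sequence(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0) {
    if (!llama_kv_cache_seq_rm(ctx, seq_id, p0, -1)) {
        llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
    }
}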
@@ -594,35 +626,93 @@ extern "C" {
 
     // Returns the maximum size in bytes of the state (rng, logits, embedding
     // and kv_cache) - will often be smaller after compacting tokens
-    LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
+    LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
+    LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx),
+        "use llama_state_get_size instead");
 
     // Copies the state to the specified destination address.
     // Destination needs to have allocated enough memory.
     // Returns the number of bytes copied
-    LLAMA_API size_t llama_copy_state_data(
+    LLAMA_API size_t llama_state_get_data(
             struct llama_context * ctx,
                          uint8_t * dst);
+    LLAMA_API DEPRECATED(size_t llama_copy_state_data(
+            struct llama_context * ctx,
+                         uint8_t * dst),
+        "use llama_state_get_data instead");
 
     // Set the state reading from the specified address
     // Returns the number of bytes read
-    LLAMA_API size_t llama_set_state_data(
+    LLAMA_API size_t llama_state_set_data(
             struct llama_context * ctx,
                    const uint8_t * src);
+    LLAMA_API DEPRECATED(size_t llama_set_state_data(
+            struct llama_context * ctx,
+                   const uint8_t * src),
+        "use llama_state_set_data instead");
 
     // Save/load session file
-    LLAMA_API bool llama_load_session_file(
+    LLAMA_API bool llama_state_load_file(
             struct llama_context * ctx,
                       const char * path_session,
                      llama_token * tokens_out,
                           size_t   n_token_capacity,
                           size_t * n_token_count_out);
+    LLAMA_API DEPRECATED(bool llama_load_session_file(
+            struct llama_context * ctx,
+                      const char * path_session,
+                     llama_token * tokens_out,
+                          size_t   n_token_capacity,
+                          size_t * n_token_count_out),
+        "use llama_state_load_file instead");
 
-    LLAMA_API bool llama_save_session_file(
+    LLAMA_API bool llama_state_save_file(
             struct llama_context * ctx,
                       const char * path_session,
               const llama_token * tokens,
                           size_t   n_token_count);
+    LLAMA_API DEPRECATED(bool llama_save_session_file(
+            struct llama_context * ctx,
+                      const char * path_session,
+               const llama_token * tokens,
+                          size_t   n_token_count),
+        "use llama_state_save_file instead");
+
+    // Get the exact size needed to copy the KV cache of a single sequence
+    LLAMA_API size_t llama_state_seq_get_size(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id);
+
+    // Copy the KV cache of a single sequence into the specified buffer
+    LLAMA_API size_t llama_state_seq_get_data(
+            struct llama_context * ctx,
+                         uint8_t * dst,
+                    llama_seq_id   seq_id);
+
+    // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
+    // Returns:
+    //  - Positive: Ok
+    //  - Zero: Failed to load
+    LLAMA_API size_t llama_state_seq_set_data(
+            struct llama_context * ctx,
+                   const uint8_t * src,
+                    llama_seq_id   dest_seq_id);
+
+    LLAMA_API size_t llama_state_seq_save_file(
+            struct llama_context * ctx,
+                      const char * filepath,
+                    llama_seq_id   seq_id,
+               const llama_token * tokens,
+                          size_t   n_token_count);
+
+    LLAMA_API size_t llama_state_seq_load_file(
+            struct llama_context * ctx,
+                      const char * filepath,
+                    llama_seq_id   dest_seq_id,
+                     llama_token * tokens_out,
+                          size_t   n_token_capacity,
+                          size_t * n_token_count_out);
 
     //
     // Decoding
     //
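A sketch of how the new per-sequence state calls compose - the sequence ids are illustrative and error handling is reduced to the documented zero-on-failure check:

#include <vector>
#include "llama.h"

// Sketch: snapshot the KV cache of sequence 0 and restore it into sequence 1.
static bool clone_sequence_state(struct llama_context * ctx) {
    const size_t n_bytes = llama_state_seq_get_size(ctx, 0);

    std::vector<uint8_t> buf(n_bytes);
    llama_state_seq_get_data(ctx, buf.data(), 0);

    // llama_state_seq_set_data() returns 0 when the data cannot be loaded
    return llama_state_seq_set_data(ctx, buf.data(), 1) != 0;
}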
@@ -684,8 +774,9 @@ extern "C" {
     // Cols: n_vocab
     LLAMA_API float * llama_get_logits(struct llama_context * ctx);
 
-    // Logits for the ith token. Equivalent to:
+    // Logits for the ith token. For positive indices, equivalent to:
     // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
+    // Negative indices can be used to access logits in reverse order, -1 is the last logit.
     // returns NULL for invalid ids.
     LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
 
@@ -697,8 +788,9 @@ extern "C" {
     // Otherwise, returns NULL.
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
 
-    // Get the embeddings for the ith token. Equivalent to:
+    // Get the embeddings for the ith token. For positive indices, equivalent to:
     // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
+    // Negative indices can be used to access embeddings in reverse order, -1 is the last embedding.
     // shape: [n_embd] (1-dimensional)
     // returns NULL for invalid ids.
     LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
@@ -718,9 +810,14 @@ extern "C" {
 
     LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
 
+    // Check if the token is supposed to end generation (end-of-generation, e.g. EOS, EOT, etc.)
+    LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
+
     // Special tokens
     LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
     LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
+    LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
+    LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
     LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
 
     // Returns -1 if unknown, 1 for true or 0 for false.
@@ -729,7 +826,7 @@ extern "C" {
     // Returns -1 if unknown, 1 for true or 0 for false.
     LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
 
-    // codellama infill tokens
+    // Codellama infill tokens
     LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
     LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
     LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
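llama_token_is_eog() subsumes per-token comparisons against llama_token_eos() and friends. A sketch of the loop shape it enables, with decoding and candidate refill elided:

#include "llama.h"

// Sketch: sample until an end-of-generation token appears. Assumes
// `candidates` is refilled from the logits after each decode step.
static void sample_until_eog(struct llama_context * ctx, const struct llama_model * model,
                             llama_token_data_array * candidates) {
    for (;;) {
        const llama_token id = llama_sample_token(ctx, candidates);
        if (llama_token_is_eog(model, id)) {
            break; // EOS, EOT, etc.
        }
        // ... feed `id` back via llama_decode() and refill `candidates` ...
    }
}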
@@ -743,26 +840,28 @@ extern "C" {
     /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
     /// @return Returns the number of tokens on success, no more than n_tokens_max
     /// @return Returns a negative number on failure - the number of tokens that would have been returned
-    /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
-    ///                Does not insert a leading space.
+    /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
+    ///                      as plaintext. Does not insert a leading space.
     LLAMA_API int32_t llama_tokenize(
         const struct llama_model * model,
                       const char * text,
                          int32_t   text_len,
                      llama_token * tokens,
                          int32_t   n_tokens_max,
-                            bool   add_bos,
-                            bool   special);
+                            bool   add_special,
+                            bool   parse_special);
 
     // Token Id -> Piece.
     // Uses the vocabulary in the provided context.
     // Does not write null terminator to the buffer.
     // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
+    // @param special If true, special tokens are rendered in the output.
     LLAMA_API int32_t llama_token_to_piece(
         const struct llama_model * model,
                      llama_token   token,
                             char * buf,
-                         int32_t   length);
+                         int32_t   length,
+                            bool   special);
 
     /// Apply chat template. Inspired by hf apply_chat_template() on python.
     /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
@@ -915,7 +1014,7 @@ extern "C" {
             struct llama_context * ctx,
           llama_token_data_array * candidates);
 
-    /// @details Randomly selects a token from the candidates based on their probabilities.
+    /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
     LLAMA_API llama_token llama_sample_token(
             struct llama_context * ctx,
           llama_token_data_array * candidates);
@@ -1002,8 +1101,9 @@ extern "C" {
 // Internal API to be implemented by llama.cpp and used by tests/benchmarks only
 #ifdef LLAMA_API_INTERNAL
 
-#include <vector>
+#include <random>
 #include <string>
+#include <vector>
 
 struct ggml_tensor;
 
@@ -1030,15 +1130,20 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal
         struct llama_context * ctx
 );
 
-std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
+void llama_grammar_accept(
         const std::vector<std::vector<llama_grammar_element>>         & rules,
         const std::vector<std::vector<const llama_grammar_element *>> & stacks,
-        const uint32_t chr);
+        const uint32_t chr,
+        std::vector<std::vector<const llama_grammar_element *>>       & new_stacks);
 
 std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
         const std::string & src,
         llama_partial_utf8   partial_start);
 
+// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
+// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
+llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
+
 #endif // LLAMA_API_INTERNAL
 
 #endif // LLAMA_H
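On the internal side, llama_grammar_accept() now fills a caller-provided vector instead of returning a new one. A migration sketch (the wrapper name is illustrative):

#define LLAMA_API_INTERNAL
#include <vector>
#include "llama.h"

// Sketch: the accepted grammar stacks become an out-parameter.
static std::vector<std::vector<const llama_grammar_element *>> advance_grammar(
        const std::vector<std::vector<llama_grammar_element>>         & rules,
        const std::vector<std::vector<const llama_grammar_element *>> & stacks,
        uint32_t chr) {
    std::vector<std::vector<const llama_grammar_element *>> new_stacks;
    llama_grammar_accept(rules, stacks, chr, new_stacks);
    // previously: return llama_grammar_accept(rules, stacks, chr);
    return new_stacks;
}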
examples/talk-llama/talk-llama.cpp CHANGED
@@ -35,10 +35,10 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
 
 std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
     std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
+    const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), false);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
+        int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), false);
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
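Passing `false` here preserves the helper's old behavior: special tokens are not rendered into the returned piece. A variant that does render them would pass `true` instead (a sketch mirroring the helper above; the function name is illustrative):

std::string llama_token_to_piece_special(const struct llama_context * ctx, llama_token token) {
    std::vector<char> result(8, 0);
    // `special = true` renders special tokens into the output buffer
    int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), true);
    if (n_tokens < 0) {
        result.resize(-n_tokens);
        n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), true);
    }
    return std::string(result.data(), n_tokens);
}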
examples/talk-llama/unicode-data.cpp CHANGED
The diff for this file is too large to render. See raw diff
 
examples/talk-llama/unicode-data.h CHANGED
@@ -5,12 +5,13 @@
 #include <utility>
 #include <vector>
 
-extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_digit;
+extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_number;
 extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_letter;
+extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_separator;
 extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace;
 extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_accent_mark;
 extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuation;
 extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
 extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
 extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
 extern const std::map<char32_t, char32_t> unicode_map_lowercase;
examples/talk-llama/unicode.cpp CHANGED
@@ -5,11 +5,15 @@
 #include <cstddef>
 #include <cstdint>
 #include <map>
+#include <regex>
 #include <stdexcept>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
+#include <locale>
+#include <codecvt>
 
 static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
     std::string result;
@@ -53,23 +57,22 @@ static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset)
         offset += 4;
         return result;
     }
-    throw std::invalid_argument("invalid string");
+    throw std::invalid_argument("failed to convert utf8 to codepoint");
 }
 
-static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
-    std::vector<uint16_t> result;
-    if (/* 0x0000 <= cp && */ cp <= 0xffff) {
-        result.emplace_back(cp);
-    }
-    else if (0x10000 <= cp && cp <= 0x10ffff) {
-        result.emplace_back(0xd800 | ((cp - 0x10000) >> 10));
-        result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff));
-    }
-    else {
-        throw std::invalid_argument("invalid cpt");
-    }
-    return result;
-}
+//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
+//    std::vector<uint16_t> result;
+//    if (/* 0x0000 <= cp && */ cp <= 0xffff) {
+//        result.emplace_back(cp);
+//        return result;
+//    }
+//    if (0x10000 <= cp && cp <= 0x10ffff) {
+//        result.emplace_back(0xd800 | ((cp - 0x10000) >> 10));
+//        result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff));
+//        return result;
+//    }
+//    throw std::invalid_argument("failed to convert codepoint to utf16");
+//}
 
 //static std::vector<uint16_t> unicode_cpts_to_utf16(const std::vector<uint32_t> & cps) {
@@ -80,56 +83,56 @@ static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
 //    return result;
 //}
 
-static uint32_t cpt_from_utf16(const std::vector<uint16_t> & utf16, size_t & offset) {
-    assert(offset < utf16.size());
-    if (((utf16[0] >> 10) << 10) != 0xd800) {
-        auto result = utf16[offset + 0];
-        offset += 1;
-        return result;
-    }
-
-    if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) {
-        throw std::invalid_argument("invalid character");
-    }
-
-    auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff));
-    offset += 2;
-    return result;
-}
+//static uint32_t unicode_cpt_from_utf16(const std::vector<uint16_t> & utf16, size_t & offset) {
+//    assert(offset < utf16.size());
+//    if (((utf16[0] >> 10) << 10) != 0xd800) {
+//        auto result = utf16[offset + 0];
+//        offset += 1;
+//        return result;
+//    }
+//
+//    if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) {
+//        throw std::invalid_argument("invalid character");
+//    }
+//
+//    auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff));
+//    offset += 2;
+//    return result;
+//}
 
 //static std::vector<uint32_t> unicode_cpts_from_utf16(const std::vector<uint16_t> & utf16) {
 //    std::vector<uint32_t> result;
 //    size_t offset = 0;
 //    while (offset < utf16.size()) {
-//        result.push_back(cpt_from_utf16(utf16, offset));
+//        result.push_back(unicode_cpt_from_utf16(utf16, offset));
 //    }
 //    return result;
 //}
 
 static std::unordered_map<uint32_t, int> unicode_cpt_type_map() {
     std::unordered_map<uint32_t, int> cpt_types;
-    for (auto p : unicode_ranges_digit) {
-        for (auto i = p.first; i <= p.second; ++ i) {
-            cpt_types[i] = CODEPOINT_TYPE_DIGIT;
+    for (auto p : unicode_ranges_number) {
+        for (auto i = p.first; i <= p.second; ++i) {
+            cpt_types[i] = CODEPOINT_TYPE_NUMBER;
         }
     }
     for (auto p : unicode_ranges_letter) {
-        for (auto i = p.first; i <= p.second; ++ i) {
+        for (auto i = p.first; i <= p.second; ++i) {
             cpt_types[i] = CODEPOINT_TYPE_LETTER;
         }
     }
-    for (auto p : unicode_ranges_whitespace) {
-        for (auto i = p.first; i <= p.second; ++ i) {
-            cpt_types[i] = CODEPOINT_TYPE_WHITESPACE;
+    for (auto p : unicode_ranges_separator) {
+        for (auto i = p.first; i <= p.second; ++i) {
+            cpt_types[i] = CODEPOINT_TYPE_SEPARATOR;
         }
     }
     for (auto p : unicode_ranges_accent_mark) {
-        for (auto i = p.first; i <= p.second; ++ i) {
+        for (auto i = p.first; i <= p.second; ++i) {
             cpt_types[i] = CODEPOINT_TYPE_ACCENT_MARK;
         }
     }
     for (auto p : unicode_ranges_punctuation) {
-        for (auto i = p.first; i <= p.second; ++ i) {
+        for (auto i = p.first; i <= p.second; ++i) {
             cpt_types[i] = CODEPOINT_TYPE_PUNCTUATION;
         }
     }
@@ -139,7 +142,7 @@ static std::unordered_map<uint32_t, int> unicode_cpt_type_map() {
         }
     }
     for (auto p : unicode_ranges_control) {
-        for (auto i = p.first; i <= p.second; ++ i) {
+        for (auto i = p.first; i <= p.second; ++i) {
             cpt_types[i] = CODEPOINT_TYPE_CONTROL;
         }
     }
@@ -194,34 +197,395 @@ static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
     return map;
 }
 
+static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
+    return conv.from_bytes(s);
+}
+
+static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) {
+    std::vector<std::string> bpe_encoded_words;
+    for (const auto & word : bpe_words) {
+        std::string text_utf;
+        auto utf_word = unicode_cpts_from_utf8(word);
+        for (size_t i = 0; i < utf_word.size(); ++i) {
+            text_utf += unicode_cpt_to_utf8(utf_word[i]);
+        }
+
+        std::string encoded_token;
+        for (char & c : text_utf) {
+            encoded_token += unicode_byte_to_utf8(c);
+        }
+        bpe_encoded_words.emplace_back(encoded_token);
+    }
+    return bpe_encoded_words;
+}
+
+// GPT2 system regex:  's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
+static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & text, const std::vector<size_t> & offsets) {
+    std::vector<size_t> bpe_offsets; // store the offset of each word
+    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
+
+    const auto cpts = unicode_cpts_from_utf8(text);
+
+    size_t start = 0;
+    for (auto offset : offsets) {
+        const size_t offset_ini = start;
+        const size_t offset_end = start + offset;
+        assert(offset_end <= cpts.size());
+        start = offset_end;
+
+        auto _get_cpt = [&] (const size_t pos) -> char32_t {
+            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
+        };
+
+        auto _get_cpt_type = [&] (const size_t pos) -> int {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
+        };
+
+        size_t _prev_end = offset_ini;
+        auto _add_token = [&] (const size_t end) -> size_t {
+            assert(_prev_end <= end && end <= offset_end);
+            size_t len = end - _prev_end;
+            if (len > 0) {
+                bpe_offsets.push_back(len);
+            }
+            _prev_end = end;
+            //if (len > 0) {
+            //    std::string s = "";
+            //    for(size_t p = end-len; p < end; p++)
+            //        s += unicode_cpt_to_utf8(cpts[p]);
+            //    printf(">>> '%s'\n", s.c_str());
+            //}
+            return len;
+        };
+
+        for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
+            const char32_t cpt = _get_cpt(pos);
+            const int cpt_type = _get_cpt_type(pos);
+
+            // regex: 's|'t|'re|'ve|'m|'ll|'d
+            if (cpt == '\'' && pos+1 < offset_end) {
+                char32_t cpt_next = _get_cpt(pos+1);
+                if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
+                    pos += _add_token(pos+2);
+                    continue;
+                }
+                if (pos+2 < offset_end) {
+                    char32_t cpt_next_next = _get_cpt(pos+2);
+                    if ((cpt_next == 'r' && cpt_next_next == 'e') ||
+                        (cpt_next == 'v' && cpt_next_next == 'e') ||
+                        (cpt_next == 'l' && cpt_next_next == 'l')) {
+                        pos += _add_token(pos+3);
+                        continue;
+                    }
+                }
+            }
+
+            char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
+            int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
+            // regex: <space>?\p{L}+
+            if (cpt2_type == CODEPOINT_TYPE_LETTER) {
+                pos += (cpt == ' ');
+                while (cpt2_type == CODEPOINT_TYPE_LETTER) {
+                    cpt2_type = _get_cpt_type(++pos);
+                }
+                _add_token(pos);
+                continue;
+            }
+            // regex: <space>?\p{N}+
+            if (cpt2_type == CODEPOINT_TYPE_NUMBER) {
+                pos += (cpt == ' ');
+                while (cpt2_type == CODEPOINT_TYPE_NUMBER) {
+                    cpt2_type = _get_cpt_type(++pos);
+                }
+                _add_token(pos);
+                continue;
+            }
+            // regex: <space>?[^\s\p{L}\p{N}]+
+            if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
+                pos += (cpt == ' ');
+                while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
+                    cpt2_type = _get_cpt_type(++pos);
+                    cpt2 = _get_cpt(pos);
+                }
+                _add_token(pos);
+                continue;
+            }
+
+            size_t num_whitespaces = 0;
+            while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) {
+                num_whitespaces++;
+            }
+
+            // regex: \s+(?!\S)
+            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
+                pos += num_whitespaces - 1;
+                _add_token(pos);
+                continue;
+            }
+
+            // regex: \s+
+            if (num_whitespaces > 0) {
+                pos += num_whitespaces;
+                _add_token(pos);
+                continue;
+            }
+
+            // no matches
+            _add_token(++pos);
+        }
+    }
+
+    return bpe_offsets;
+}
+
+// LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
+static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string & text, const std::vector<size_t> & offsets) {
+    std::vector<size_t> bpe_offsets; // store the offset of each word
+    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
+
+    const auto cpts = unicode_cpts_from_utf8(text);
+
+    size_t start = 0;
+    for (auto offset : offsets) {
+        const size_t offset_ini = start;
+        const size_t offset_end = start + offset;
+        assert(offset_end <= cpts.size());
+        start = offset_end;
+
+        auto _get_cpt = [&] (const size_t pos) -> char32_t {
+            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
+        };
+
+        auto _get_cpt_type = [&] (const size_t pos) -> int {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
+        };
+
+        size_t _prev_end = offset_ini;
+        auto _add_token = [&] (const size_t end) -> size_t {
+            assert(_prev_end <= end && end <= offset_end);
+            size_t len = end - _prev_end;
+            if (len > 0) {
+                bpe_offsets.push_back(len);
+            }
+            _prev_end = end;
+            //if (len > 0) {
+            //    std::string s = "";
+            //    for(size_t p = end-len; p < end; p++)
+            //        s += unicode_cpt_to_utf8(cpts[p]);
+            //    printf(">>> '%s'\n", s.c_str());
+            //}
+            return len;
+        };
+
+        for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
+            const char32_t cpt = _get_cpt(pos);
+            const int cpt_type = _get_cpt_type(pos);
+
+            // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
+            if (cpt == '\'' && pos+1 < offset_end) {
+                char32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
+                if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
+                    pos += _add_token(pos+2);
+                    continue;
+                }
+                if (pos+2 < offset_end) {
+                    char32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
+                    if ((cpt_next == 'r' && cpt_next_next == 'e') ||
+                        (cpt_next == 'v' && cpt_next_next == 'e') ||
+                        (cpt_next == 'l' && cpt_next_next == 'l')) {
+                        pos += _add_token(pos+3);
+                        continue;
+                    }
+                }
+            }
+
+            // regex: [^\r\n\p{L}\p{N}]?\p{L}+  //####FIXME: the first \p{L} is correct?
+            if (cpt != '\r' && cpt != '\n' && /*cpt_type != CODEPOINT_TYPE_LETTER &&*/ cpt_type != CODEPOINT_TYPE_NUMBER) {
+                if (cpt_type == CODEPOINT_TYPE_LETTER || _get_cpt_type(pos+1) == CODEPOINT_TYPE_LETTER) {  // one or more letters
+                    pos++;
+                    while (_get_cpt_type(pos) == CODEPOINT_TYPE_LETTER) {
+                        pos++;
+                    }
+                    _add_token(pos);
+                    continue;
+                }
+            }
+
+            // regex: \p{N}{1,3}
+            if (cpt_type == CODEPOINT_TYPE_NUMBER) {
+                size_t ini = pos;
+                while (_get_cpt_type(pos) == CODEPOINT_TYPE_NUMBER) {
+                    if (++pos - ini >= 3) {
+                        _add_token(pos);
+                        ini = pos;
+                    }
+                }
+                _add_token(pos);
+                continue;
+            }
+
+            // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
+            char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
+            int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
+            if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
+                pos += (cpt == ' ');
+                while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
+                    cpt2_type = _get_cpt_type(++pos);
+                    cpt2 = _get_cpt(pos);
+                }
+                while (cpt2 == '\r' || cpt2 == '\n') {
+                    cpt2 = _get_cpt(++pos);
+                }
+                _add_token(pos);
+                continue;
+            }
+
+            size_t num_whitespaces = 0;
+            size_t last_end_r_or_n = 0;
+            while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) {
+                char32_t cpt2 = _get_cpt(pos+num_whitespaces);
+                if (cpt2 == '\r' || cpt2 == '\n') {
+                    last_end_r_or_n = pos + num_whitespaces + 1;
+                }
+                num_whitespaces++;
+            }
+
+            // regex: \s*[\r\n]+
+            if (last_end_r_or_n > 0) {
+                pos = last_end_r_or_n;
+                _add_token(pos);
+                continue;
+            }
+
+            // regex: \s+(?!\S)
+            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
+                pos += num_whitespaces - 1;
+                _add_token(pos);
+                continue;
+            }
+
+            // regex: \s+
+            if (num_whitespaces > 0) {
+                pos += num_whitespaces;
+                _add_token(pos);
+                continue;
+            }
+
+            // no matches
+            _add_token(++pos);
+        }
+    }
+
+    return bpe_offsets;
+}
+
+// use std::wregex to split the text
+static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector<size_t> & offsets) {
+    std::wregex expr(regex_expr);
+    std::vector<size_t> bpe_offsets; // store the offset of each word
+    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
+    size_t start = 0;
+    for (auto offset : offsets) {
+        std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
+        std::wcregex_iterator end;
+
+        int64_t start_idx = 0;
+        while (it != end) {
+            std::wcmatch match = *it;
+            if (match.position() > start_idx) {
+                bpe_offsets.emplace_back(match.position() - start_idx);
+            }
+            bpe_offsets.emplace_back(match.length());
+            start_idx = match.position() + match.length();
+            ++it;
+        }
+
+        if (start_idx < (int64_t) offset) {
+            bpe_offsets.emplace_back(offset - start_idx);
+        }
+        start += offset;
+    }
+
+    return bpe_offsets;
+}
+
+// use std::regex to split the text
+static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
+    std::regex expr(regex_expr);
+    std::vector<size_t> bpe_offsets; // store the offset of each word
+    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
+    size_t start = 0;
+    for (auto offset : offsets) {
+        std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
+        std::cregex_iterator end;
+
+        int64_t start_idx = 0;
+        while (it != end) {
+            std::cmatch match = *it;
+            if (match.position() > start_idx) {
+                bpe_offsets.emplace_back(match.position() - start_idx);
+            }
+            bpe_offsets.emplace_back(match.length());
+            start_idx = match.position() + match.length();
+            ++it;
+        }
+
+        if (start_idx < (int64_t) offset) {
+            bpe_offsets.emplace_back(offset - start_idx);
+        }
+        start += offset;
+    }
+
+    return bpe_offsets;
+}
+
+static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
+    std::vector<size_t> bpe_offsets;
+
+    if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
+        bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
+    } else if (
+            regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
+            regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
+
+        bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
+    }
+
+    return bpe_offsets;
+}
+
 //
 // interface
 //
 
 std::string unicode_cpt_to_utf8(uint32_t cp) {
     std::string result;
+
     if (/* 0x00 <= cp && */ cp <= 0x7f) {
         result.push_back(cp);
+        return result;
     }
-    else if (0x80 <= cp && cp <= 0x7ff) {
+    if (0x80 <= cp && cp <= 0x7ff) {
         result.push_back(0xc0 | ((cp >> 6) & 0x1f));
         result.push_back(0x80 | (cp & 0x3f));
+        return result;
     }
-    else if (0x800 <= cp && cp <= 0xffff) {
+    if (0x800 <= cp && cp <= 0xffff) {
         result.push_back(0xe0 | ((cp >> 12) & 0x0f));
         result.push_back(0x80 | ((cp >> 6) & 0x3f));
         result.push_back(0x80 | (cp & 0x3f));
+        return result;
     }
-    else if (0x10000 <= cp && cp <= 0x10ffff) {
+    if (0x10000 <= cp && cp <= 0x10ffff) {
         result.push_back(0xf0 | ((cp >> 18) & 0x07));
         result.push_back(0x80 | ((cp >> 12) & 0x3f));
         result.push_back(0x80 | ((cp >> 6) & 0x3f));
         result.push_back(0x80 | (cp & 0x3f));
+        return result;
     }
-    else {
-        throw std::invalid_argument("invalid codepoint");
-    }
-    return result;
+
+    throw std::invalid_argument("invalid codepoint");
 }
 
 std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
@@ -261,6 +625,19 @@ int unicode_cpt_type(const std::string & utf8) {
     return unicode_cpt_type(unicode_cpt_from_utf8(utf8, offset));
 }
 
+bool unicode_cpt_is_whitespace(uint32_t cp) {
+    static const std::unordered_set<uint32_t> is_whitespace = [] {
+        std::unordered_set<uint32_t> is_whitespace;
+        for (auto p : unicode_ranges_whitespace) {
+            for (auto i = p.first; i <= p.second; ++i) {
+                is_whitespace.insert(i);
+            }
+        }
+        return is_whitespace;
+    }();
+    return (bool) is_whitespace.count(cp);
+}
+
 std::string unicode_byte_to_utf8(uint8_t byte) {
     static std::unordered_map<uint8_t, std::string> map = unicode_byte_to_utf8_map();
     return map.at(byte);
@@ -275,3 +652,167 @@ char32_t unicode_tolower(char32_t cp) {
     auto it = unicode_map_lowercase.find(cp);
     return it == unicode_map_lowercase.end() ? cp : it->second;
 }
+
+std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
+    // unicode categories
+    static const std::map<std::string, int> k_ucat_enum = {
+        { "\\p{N}", CODEPOINT_TYPE_NUMBER },
+        { "\\p{L}", CODEPOINT_TYPE_LETTER },
+        { "\\p{P}", CODEPOINT_TYPE_PUNCTUATION },
+    };
+
+    static const std::map<int, int> k_ucat_cpt = {
+        { CODEPOINT_TYPE_NUMBER,      0xD1 },
+        { CODEPOINT_TYPE_LETTER,      0xD2 },
+        { CODEPOINT_TYPE_PUNCTUATION, 0xD3 },
+    };
+
+    static const std::map<int, std::string> k_ucat_map = {
+        { CODEPOINT_TYPE_NUMBER,      "\x30-\x39" }, // 0-9
+        { CODEPOINT_TYPE_LETTER,      "\x41-\x5A\x61-\x7A" }, // A-Za-z
+        { CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
+    };
+
+    // compute collapsed codepoints only if needed by at least one regex
+    bool need_collapse = false;
+    for (auto & regex_expr : regex_exprs) {
+        // search for unicode categories
+        for (const auto & ucat : k_ucat_enum) {
+            if (std::string::npos != regex_expr.find(ucat.first)) {
+                need_collapse = true;
+                break;
+            }
+        }
+    }
+
+    const auto cpts = unicode_cpts_from_utf8(text);
+
+    // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
+    // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
+    std::string text_collapsed;
+    if (need_collapse) {
+        // collapse all unicode categories
+        text_collapsed.resize(cpts.size());
+
+        for (size_t i = 0; i < cpts.size(); ++i) {
+            // keep single-byte codepoints as is
+            if (cpts[i] < 128) {
+                text_collapsed[i] = cpts[i];
+                continue;
+            }
+
+            const int cpt_type = unicode_cpt_type(cpts[i]);
+
+            if (k_ucat_cpt.find(cpt_type) != k_ucat_cpt.end()) {
+                text_collapsed[i] = k_ucat_cpt.at(cpt_type);
+            } else {
+                text_collapsed[i] = (char) 0xD0; // fallback
+            }
+        }
+    }
+
+    std::vector<size_t> bpe_offsets = { cpts.size() };
+
+    for (auto & regex_expr : regex_exprs) {
+        // first, see if we have an efficient custom regex implementation
+        auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
+
+        if (!tmp.empty()) {
+            bpe_offsets = std::move(tmp);
+            continue;
+        }
+
+        // fallback to general-purpose std::regex / std::wregex
+        try {
+            // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
+            // with the corresponding collapsed representation
+            bool use_collapsed = false;
+            for (auto & ucat : k_ucat_enum) {
+                if (std::string::npos != regex_expr.find(ucat.first)) {
+                    use_collapsed = true;
+                    break;
+                }
+            }
+
+            if (use_collapsed) {
+                // sanity-check that the original regex does not contain any non-ASCII characters
+                const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
+                for (size_t i = 0; i < cpts_regex.size(); ++i) {
+                    if (cpts_regex[i] >= 128) {
+                        throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported");
+                    }
+                }
+
+                // generate a collapsed representation of the regex
+                std::string regex_expr_collapsed;
+
+                // track if we are inside [], because nested [] are not allowed
+                bool inside = false;
+                for (size_t i = 0; i < regex_expr.size(); ++i) {
+                    if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) {
+                        regex_expr_collapsed += '[';
+                        inside = true;
+                        continue;
+                    }
+
+                    if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') {
+                        regex_expr_collapsed += ']';
+                        inside = false;
+                        continue;
+                    }
+
+                    if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() &&
+                        regex_expr[i + 1] == 'p' &&
+                        regex_expr[i + 2] == '{' &&
+                        regex_expr[i + 4] == '}') {
+                        const std::string pat = regex_expr.substr(i, 5);
+                        if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
+                            if (!inside) {
+                                regex_expr_collapsed += '[';
+                            }
+                            regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
+                            regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
+                            if (!inside) {
+                                regex_expr_collapsed += ']';
+                            }
+                            i += 4;
+                            continue;
+                        }
+                    }
+
+                    regex_expr_collapsed += regex_expr[i];
+                }
+
+                //printf("text_collapsed: %s\n", text_collapsed.c_str());
+                //printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
+                bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
+            } else {
+                // no unicode category used, we can use std::wregex directly
+                const std::wstring wtext       = unicode_wstring_from_utf8(text);
+                const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
+
+                //printf("text: %s\n", text.c_str());
+                //printf("regex_expr: %s\n", regex_expr.c_str());
+                bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
+            }
+        } catch (std::regex_error & e) {
+            fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
+            fprintf(stderr, "Regex error: %s\n", e.what());
+            throw std::runtime_error("Failed to process regex");
+        }
+    }
+
+    std::vector<std::string> bpe_words;
+    bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size
+
+    size_t start = 0;
+    for (size_t & offset : bpe_offsets) {
+        bpe_words.emplace_back();
+        for (size_t i = start; i < start + offset; ++i) {
+            bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);
+        }
+        start += offset;
+    }
+
+    return unicode_byte_encoding_process(bpe_words);
+}
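unicode_regex_split() above is the public entry point the BPE pretokenizers go through. A standalone sketch of calling it with the GPT-2 pattern from unicode_regex_split_custom() (the input string is illustrative):

#include <cstdio>
#include <string>
#include <vector>
#include "unicode.h"

int main() {
    const std::vector<std::string> regex_exprs = {
        "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
    };

    // each element is a byte-encoded "word", ready for BPE merging
    for (const auto & word : unicode_regex_split("Hello there, it's 2024!", regex_exprs)) {
        std::printf("'%s'\n", word.c_str());
    }
    return 0;
}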
examples/talk-llama/unicode.h CHANGED
@@ -5,9 +5,9 @@
 #include <vector>
 
 #define CODEPOINT_TYPE_UNIDENTIFIED 0
-#define CODEPOINT_TYPE_DIGIT        1
+#define CODEPOINT_TYPE_NUMBER       1
 #define CODEPOINT_TYPE_LETTER       2
-#define CODEPOINT_TYPE_WHITESPACE   3
+#define CODEPOINT_TYPE_SEPARATOR    3
 #define CODEPOINT_TYPE_ACCENT_MARK  4
 #define CODEPOINT_TYPE_PUNCTUATION  5
 #define CODEPOINT_TYPE_SYMBOL       6
@@ -21,8 +21,11 @@ std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & c
 int unicode_cpt_type(uint32_t cp);
 int unicode_cpt_type(const std::string & utf8);
 
+bool unicode_cpt_is_whitespace(uint32_t cp);
+
 std::string unicode_byte_to_utf8(uint8_t byte);
 uint8_t unicode_utf8_to_byte(const std::string & utf8);
 
-// simple tolower that only implements one-to-one mapping, not one-to-many
 char32_t unicode_tolower(char32_t cp);
+
+std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);
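Since CODEPOINT_TYPE_WHITESPACE is gone, whitespace checks now go through the new predicate while digits fall under CODEPOINT_TYPE_NUMBER. A tiny sketch:

#include <cassert>
#include "unicode.h"

int main() {
    assert(unicode_cpt_is_whitespace(' '));                  // whitespace is a predicate now
    assert(unicode_cpt_type(U'5') == CODEPOINT_TYPE_NUMBER); // digits are in the "number" category
    return 0;
}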