ggerganov and Chris Raethke committed
Commit 7006035 · unverified · 1 Parent(s): 6023f2d

sync : ggml (backend v2, k-quants, CUDA opts, Metal opts, etc.) (#1422)


* sync : ggml (backend v2, k-quants, CUDA opts, Metal opts, etc.)

* metal : allow env metal variable to override resource path (#1415)

* Allow env variable to override resource path

* Update ggml-metal.m

---------

Co-authored-by: Georgi Gerganov <[email protected]>

* sync : restore common / main from `master`

* sync : restore whisper from `master`

* talk-llama : update to latest llama.cpp

* ruby : fix build

* ggml : fix 32-bit ARM build

* ggml : fix MIN / MAX macro collisions + update ios bindings

* ggml : fix ifdefs and MIN / MAX again

* examples : fix Obj-C and Swift examples

* ggml : fix 32-bit ARM compatibility

* ggml : one more attempt to fix 32-bit ARM compat

* whisper : fix support for larger graphs

---------

Co-authored-by: Chris Raethke <[email protected]>
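
As a rough illustration of the "allow env variable to override resource path" items above: ggml-metal.m can take the location of its Metal shader source from the environment instead of the app bundle. A minimal C sketch of that lookup pattern, assuming the variable is named GGML_METAL_PATH_RESOURCES (verify the exact name against ggml-metal.m):

    #include <stdio.h>
    #include <stdlib.h>

    int main(void) {
        // assumed variable name -- check ggml-metal.m for the exact string
        const char * override = getenv("GGML_METAL_PATH_RESOURCES");

        // fall back to the default bundle location when no override is set
        const char * path = override ? override : "<default bundle resource path>";
        printf("loading Metal resources from: %s\n", path);

        return 0;
    }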

CMakeLists.txt CHANGED
@@ -464,6 +464,10 @@ add_library(${TARGET}
     ggml.c
     ggml-alloc.h
     ggml-alloc.c
+    ggml-backend.h
+    ggml-backend.c
+    ggml-quants.h
+    ggml-quants.c
     ${GGML_SOURCES_METAL}
     ${GGML_SOURCES_CUDA}
     ${GGML_SOURCES_OPENCL}
Makefile CHANGED
@@ -301,7 +301,13 @@ ggml.o: ggml.c ggml.h ggml-cuda.h
 ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h
 	$(CC) $(CFLAGS) -c $< -o $@
 
-WHISPER_OBJ += ggml-alloc.o
+ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h
+	$(CC) $(CFLAGS) -c $< -o $@
+
+ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
+	$(CC) $(CFLAGS) -c $< -o $@
+
+WHISPER_OBJ += ggml-alloc.o ggml-backend.o ggml-quants.o
 
 whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
bindings/ios CHANGED
@@ -1 +1 @@
-Subproject commit 22a9eef021afc67f2154bc9811ed620b26299d1b
+Subproject commit 44b39fd4ec616a9ce66635e36045372d03dd45e0
bindings/ruby/ext/extconf.rb CHANGED
@@ -3,8 +3,14 @@ system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .")
 system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
 system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
 system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
+system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-impl.h')} .")
 system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-alloc.h')} .")
 system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-alloc.c')} .")
+system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-backend-impl.h')} .")
+system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-backend.h')} .")
+system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-backend.c')} .")
+system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-quants.h')} .")
+system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-quants.c')} .")
 system("cp #{File.join(File.dirname(__FILE__),'..','..','..','examples','dr_wav.h')} .")
bindings/ruby/ext/ggml-backend-impl.h ADDED
@@ -0,0 +1,87 @@
+#pragma once
+
+// ggml-backend internal header
+
+#include "ggml-backend.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+    //
+    // Backend buffer
+    //
+
+    typedef void * ggml_backend_buffer_context_t;
+
+    struct ggml_backend_buffer_i {
+        void   (*free_buffer)   (ggml_backend_buffer_t buffer);
+        void * (*get_base)      (ggml_backend_buffer_t buffer); // get base pointer
+        size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback
+        void   (*init_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback
+        void   (*free_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback
+    };
+
+    struct ggml_backend_buffer {
+        struct ggml_backend_buffer_i iface;
+
+        ggml_backend_t                backend;
+        ggml_backend_buffer_context_t context;
+
+        size_t size;
+    };
+
+    GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
+            struct ggml_backend           * backend,
+            struct ggml_backend_buffer_i    iface,
+            ggml_backend_buffer_context_t   context,
+            size_t                          size);
+
+    //
+    // Backend
+    //
+
+    typedef void * ggml_backend_context_t;
+
+    struct ggml_backend_i {
+        const char * (*get_name)(ggml_backend_t backend);
+
+        void (*free)(ggml_backend_t backend);
+
+        // buffer allocation
+        ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size);
+
+        // get buffer alignment
+        size_t (*get_alignment)(ggml_backend_t backend);
+
+        // tensor data access
+        // these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize
+        void (*set_tensor_async)(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+        void (*synchronize)     (ggml_backend_t backend);
+
+        // (optional) copy tensor between different backends, allow for single-copy tranfers
+        void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+        void (*cpy_tensor_to)  (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+
+        // compute graph with a plan
+        ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+        void                      (*graph_plan_free)   (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+        void                      (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+
+        // compute graph without a plan
+        void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+
+        // check if the backend supports an operation
+        bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
+    };
+
+    struct ggml_backend {
+        struct ggml_backend_i iface;
+
+        ggml_backend_context_t context;
+    };
+
+#ifdef  __cplusplus
+}
+#endif
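
For orientation, here is a minimal sketch of how a backend might plug into the buffer interface declared above. It mirrors the CPU backend included later in this commit; the my_* names are invented for illustration and are not part of the commit.

    #include <stdlib.h>
    #include "ggml-backend-impl.h"

    // in this toy buffer, the context is simply the malloc'd region itself
    static void my_buffer_free_buffer(ggml_backend_buffer_t buffer) {
        free(buffer->context);
    }

    static void * my_buffer_get_base(ggml_backend_buffer_t buffer) {
        return buffer->context;
    }

    static struct ggml_backend_buffer_i my_buffer_i = {
        /* .free_buffer    = */ my_buffer_free_buffer,
        /* .get_base       = */ my_buffer_get_base,
        /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
        /* .init_tensor    = */ NULL, // no per-tensor initialization needed
        /* .free_tensor    = */ NULL, // no per-tensor cleanup needed
    };

    // a backend would typically call this from its alloc_buffer callback
    static ggml_backend_buffer_t my_alloc_buffer(ggml_backend_t backend, size_t size) {
        void * data = malloc(size);
        return ggml_backend_buffer_init(backend, my_buffer_i, data, size);
    }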
bindings/ruby/ext/ggml-backend.c ADDED
@@ -0,0 +1,950 @@
1
+ #include "ggml-backend-impl.h"
2
+ #include "ggml-alloc.h"
3
+ #include "ggml-impl.h"
4
+
5
+ #include <assert.h>
6
+ #include <limits.h>
7
+ #include <stdarg.h>
8
+ #include <stdio.h>
9
+ #include <stdlib.h>
10
+ #include <string.h>
11
+
12
+ #define UNUSED GGML_UNUSED
13
+
14
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
15
+
16
+ // backend buffer
17
+
18
+ ggml_backend_buffer_t ggml_backend_buffer_init(
19
+ struct ggml_backend * backend,
20
+ struct ggml_backend_buffer_i iface,
21
+ ggml_backend_buffer_context_t context,
22
+ size_t size) {
23
+ ggml_backend_buffer_t buffer = malloc(sizeof(struct ggml_backend_buffer));
24
+
25
+ GGML_ASSERT(iface.get_base != NULL);
26
+
27
+ (*buffer) = (struct ggml_backend_buffer) {
28
+ /* .interface = */ iface,
29
+ /* .backend = */ backend,
30
+ /* .context = */ context,
31
+ /* .size = */ size,
32
+ };
33
+
34
+ return buffer;
35
+ }
36
+
37
+ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
38
+ if (buffer == NULL) {
39
+ return;
40
+ }
41
+
42
+ if (buffer->iface.free_buffer != NULL) {
43
+ buffer->iface.free_buffer(buffer);
44
+ }
45
+ free(buffer);
46
+ }
47
+
48
+ size_t ggml_backend_buffer_get_alignment(ggml_backend_buffer_t buffer) {
49
+ return ggml_backend_get_alignment(buffer->backend);
50
+ }
51
+
52
+ size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
53
+ return buffer->size;
54
+ }
55
+
56
+ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
57
+ void * base = buffer->iface.get_base(buffer);
58
+
59
+ GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL");
60
+
61
+ return base;
62
+ }
63
+
64
+ size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
65
+ // get_alloc_size is optional, defaults to ggml_nbytes
66
+ if (buffer->iface.get_alloc_size) {
67
+ return buffer->iface.get_alloc_size(buffer, tensor);
68
+ }
69
+ return ggml_nbytes(tensor);
70
+ }
71
+
72
+ void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
73
+ // init_tensor is optional
74
+ if (buffer->iface.init_tensor) {
75
+ buffer->iface.init_tensor(buffer, tensor);
76
+ }
77
+ }
78
+
79
+ void ggml_backend_buffer_free_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
80
+ // free_tensor is optional
81
+ if (buffer->iface.free_tensor) {
82
+ buffer->iface.free_tensor(buffer, tensor);
83
+ }
84
+ }
85
+
86
+ // backend
87
+
88
+ ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor) {
89
+ return tensor->buffer ? tensor->buffer->backend : NULL;
90
+ }
91
+
92
+ const char * ggml_backend_name(ggml_backend_t backend) {
93
+ if (backend == NULL) {
94
+ return "NULL";
95
+ }
96
+ return backend->iface.get_name(backend);
97
+ }
98
+
99
+ void ggml_backend_free(ggml_backend_t backend) {
100
+ if (backend == NULL) {
101
+ return;
102
+ }
103
+
104
+ backend->iface.free(backend);
105
+ }
106
+
107
+ ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) {
108
+ return backend->iface.alloc_buffer(backend, size);
109
+ }
110
+
111
+ size_t ggml_backend_get_alignment(ggml_backend_t backend) {
112
+ return backend->iface.get_alignment(backend);
113
+ }
114
+
115
+ void ggml_backend_tensor_set_async(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
116
+ ggml_get_backend(tensor)->iface.set_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size);
117
+ }
118
+
119
+ void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
120
+ ggml_get_backend(tensor)->iface.get_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size);
121
+ }
122
+
123
+ void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
124
+ ggml_backend_t backend = ggml_get_backend(tensor);
125
+
126
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
127
+ GGML_ASSERT(backend != NULL && "tensor backend not set");
128
+
129
+ backend->iface.set_tensor_async(backend, tensor, data, offset, size);
130
+ backend->iface.synchronize(backend);
131
+ }
132
+
133
+ void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
134
+ ggml_backend_t backend = ggml_get_backend(tensor);
135
+
136
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
137
+ GGML_ASSERT(backend != NULL && "tensor backend not set");
138
+
139
+ backend->iface.get_tensor_async(backend, tensor, data, offset, size);
140
+ backend->iface.synchronize(backend);
141
+ }
142
+
143
+ void ggml_backend_synchronize(ggml_backend_t backend) {
144
+ backend->iface.synchronize(backend);
145
+ }
146
+
147
+ ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
148
+ return backend->iface.graph_plan_create(backend, cgraph);
149
+ }
150
+
151
+ void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
152
+ backend->iface.graph_plan_free(backend, plan);
153
+ }
154
+
155
+ void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
156
+ backend->iface.graph_plan_compute(backend, plan);
157
+ }
158
+
159
+ void ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
160
+ backend->iface.graph_compute(backend, cgraph);
161
+ }
162
+
163
+ bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
164
+ return backend->iface.supports_op(backend, op);
165
+ }
166
+
167
+ // backend copy
168
+
169
+ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
170
+ if (a->type != b->type) {
171
+ return false;
172
+ }
173
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
174
+ if (a->ne[i] != b->ne[i]) {
175
+ return false;
176
+ }
177
+ if (a->nb[i] != b->nb[i]) {
178
+ return false;
179
+ }
180
+ }
181
+ return true;
182
+ }
183
+
184
+ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) {
185
+ //printf("src: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", src->name, (int)src->ne[0], (int)src->ne[1], (int)src->ne[2], (int)src->ne[3], (int)src->nb[0], (int)src->nb[1], (int)src->nb[2], (int)src->nb[3]);
186
+ //printf("dst: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", dst->name, (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], (int)dst->nb[0], (int)dst->nb[1], (int)dst->nb[2], (int)dst->nb[3]);
187
+ GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
188
+
189
+ // fprintf(stderr, "cpy tensor %s from %s to %s (%lu bytes)\n", src->name, ggml_backend_name(src->backend), ggml_backend_name(dst->backend), ggml_nbytes(src));
190
+
191
+ if (src == dst) {
192
+ return;
193
+ }
194
+
195
+ // TODO: allow backends to support copy to/from same backend
196
+
197
+ if (ggml_get_backend(dst)->iface.cpy_tensor_from != NULL) {
198
+ ggml_get_backend(dst)->iface.cpy_tensor_from(ggml_get_backend(dst)->context, src, dst);
199
+ } else if (ggml_get_backend(src)->iface.cpy_tensor_to != NULL) {
200
+ ggml_get_backend(src)->iface.cpy_tensor_to(ggml_get_backend(src)->context, src, dst);
201
+ } else {
202
+ // shouldn't be hit when copying from/to CPU
203
+ #ifndef NDEBUG
204
+ fprintf(stderr, "ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to are implemented for backends %s and %s, falling back to get/set\n", ggml_backend_name(src->buffer->backend), ggml_backend_name(dst->buffer->backend));
205
+ #endif
206
+ size_t nbytes = ggml_nbytes(src);
207
+ void * data = malloc(nbytes);
208
+ ggml_backend_tensor_get(src, data, 0, nbytes);
209
+ ggml_backend_tensor_set(dst, data, 0, nbytes);
210
+ free(data);
211
+ }
212
+ }
213
+
214
+ // backend CPU
215
+
216
+ struct ggml_backend_cpu_context {
217
+ int n_threads;
218
+ void * work_data;
219
+ size_t work_size;
220
+ };
221
+
222
+ static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
223
+ return "CPU";
224
+
225
+ UNUSED(backend);
226
+ }
227
+
228
+ static void ggml_backend_cpu_free(ggml_backend_t backend) {
229
+ struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
230
+ free(cpu_ctx->work_data);
231
+ free(cpu_ctx);
232
+ free(backend);
233
+ }
234
+
235
+ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
236
+ return (void *)buffer->context;
237
+ }
238
+
239
+ static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
240
+ free(buffer->context);
241
+ UNUSED(buffer);
242
+ }
243
+
244
+ static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
245
+ /* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer,
246
+ /* .get_base = */ ggml_backend_cpu_buffer_get_base,
247
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
248
+ /* .init_tensor = */ NULL, // no initialization required
249
+ /* .free_tensor = */ NULL, // no cleanup required
250
+ };
251
+
252
+ // for buffers from ptr, free is not called
253
+ static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
254
+ /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
255
+ /* .get_base = */ ggml_backend_cpu_buffer_get_base,
256
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
257
+ /* .init_tensor = */ NULL,
258
+ /* .free_tensor = */ NULL,
259
+ };
260
+
261
+ static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512
262
+
263
+ static ggml_backend_buffer_t ggml_backend_cpu_alloc_buffer(ggml_backend_t backend, size_t size) {
264
+ size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned
265
+ void * data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC?
266
+
267
+ GGML_ASSERT(data != NULL && "failed to allocate buffer");
268
+
269
+ return ggml_backend_buffer_init(backend, cpu_backend_buffer_i, data, size);
270
+ }
271
+
272
+ static size_t ggml_backend_cpu_get_alignment(ggml_backend_t backend) {
273
+ return TENSOR_ALIGNMENT;
274
+ UNUSED(backend);
275
+ }
276
+
277
+ static void ggml_backend_cpu_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
278
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
279
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
280
+
281
+ memcpy((char *)tensor->data + offset, data, size);
282
+
283
+ UNUSED(backend);
284
+ }
285
+
286
+ static void ggml_backend_cpu_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
287
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
288
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
289
+
290
+ memcpy(data, (const char *)tensor->data + offset, size);
291
+
292
+ UNUSED(backend);
293
+ }
294
+
295
+ static void ggml_backend_cpu_synchronize(ggml_backend_t backend) {
296
+ UNUSED(backend);
297
+ }
298
+
299
+ static void ggml_backend_cpu_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
300
+ ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
301
+
302
+ UNUSED(backend);
303
+ }
304
+
305
+ static void ggml_backend_cpu_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
306
+ ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
307
+
308
+ UNUSED(backend);
309
+ }
310
+
311
+ struct ggml_backend_plan_cpu {
312
+ struct ggml_cplan cplan;
313
+ struct ggml_cgraph cgraph;
314
+ };
315
+
316
+ static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
317
+ struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
318
+
319
+ struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu));
320
+
321
+ cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads);
322
+ cpu_plan->cgraph = *cgraph;
323
+
324
+ if (cpu_plan->cplan.work_size > 0) {
325
+ cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size);
326
+ }
327
+
328
+ return cpu_plan;
329
+ }
330
+
331
+ static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
332
+ struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
333
+
334
+ free(cpu_plan->cplan.work_data);
335
+ free(cpu_plan);
336
+
337
+ UNUSED(backend);
338
+ }
339
+
340
+ static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
341
+ struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
342
+
343
+ ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
344
+
345
+ UNUSED(backend);
346
+ }
347
+
348
+ static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
349
+ struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
350
+
351
+ struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads);
352
+
353
+ if (cpu_ctx->work_size < cplan.work_size) {
354
+ // TODO: may be faster to free and use malloc to avoid the copy
355
+ cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size);
356
+ cpu_ctx->work_size = cplan.work_size;
357
+ }
358
+
359
+ cplan.work_data = cpu_ctx->work_data;
360
+
361
+ ggml_graph_compute(cgraph, &cplan);
362
+ }
363
+
364
+ static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
365
+ return true;
366
+ UNUSED(backend);
367
+ UNUSED(op);
368
+ }
369
+
370
+ static struct ggml_backend_i cpu_backend_i = {
371
+ /* .get_name = */ ggml_backend_cpu_name,
372
+ /* .free = */ ggml_backend_cpu_free,
373
+ /* .alloc_buffer = */ ggml_backend_cpu_alloc_buffer,
374
+ /* .get_alignment = */ ggml_backend_cpu_get_alignment,
375
+ /* .set_tensor_async = */ ggml_backend_cpu_set_tensor_async,
376
+ /* .get_tensor_async = */ ggml_backend_cpu_get_tensor_async,
377
+ /* .synchronize = */ ggml_backend_cpu_synchronize,
378
+ /* .cpy_tensor_from = */ ggml_backend_cpu_cpy_tensor_from,
379
+ /* .cpy_tensor_to = */ ggml_backend_cpu_cpy_tensor_to,
380
+ /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create,
381
+ /* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free,
382
+ /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute,
383
+ /* .graph_compute = */ ggml_backend_cpu_graph_compute,
384
+ /* .supports_op = */ ggml_backend_cpu_supports_op,
385
+ };
386
+
387
+ ggml_backend_t ggml_backend_cpu_init(void) {
388
+ struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context));
389
+
390
+ ctx->n_threads = GGML_DEFAULT_N_THREADS;
391
+ ctx->work_data = NULL;
392
+ ctx->work_size = 0;
393
+
394
+ ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend));
395
+
396
+ *cpu_backend = (struct ggml_backend) {
397
+ /* .interface = */ cpu_backend_i,
398
+ /* .context = */ ctx
399
+ };
400
+ return cpu_backend;
401
+ }
402
+
403
+ bool ggml_backend_is_cpu(ggml_backend_t backend) {
404
+ return backend->iface.get_name == ggml_backend_cpu_name;
405
+ }
406
+
407
+ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
408
+ GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
409
+
410
+ struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
411
+ ctx->n_threads = n_threads;
412
+ }
413
+
414
+ ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size) {
415
+ return ggml_backend_buffer_init(backend_cpu, cpu_backend_buffer_i_from_ptr, ptr, size);
416
+ }
417
+
418
+ // scheduler
419
+
420
+ #define GGML_MAX_BACKENDS 4
421
+ #define GGML_MAX_SPLITS 256
422
+ #define GGML_MAX_SPLIT_INPUTS 16
423
+
424
+ struct ggml_backend_sched_split {
425
+ ggml_tallocr_t tallocr;
426
+ int i_start;
427
+ int i_end;
428
+ struct ggml_tensor * inputs[GGML_MAX_SPLIT_INPUTS];
429
+ int n_inputs;
430
+ struct ggml_cgraph * graph;
431
+ };
432
+
433
+ struct ggml_backend_sched {
434
+ int n_backends;
435
+ ggml_backend_t backends[GGML_MAX_BACKENDS];
436
+ ggml_tallocr_t tallocs[GGML_MAX_BACKENDS];
437
+
438
+ ggml_gallocr_t galloc;
439
+
440
+ struct ggml_hash_set hash_set;
441
+ ggml_tallocr_t * node_talloc; // [hash_set.size]
442
+ struct ggml_tensor * (* node_copies)[GGML_MAX_BACKENDS]; // [hash_set.size][GGML_MAX_BACKENDS]
443
+
444
+ struct ggml_cgraph * graph;
445
+ struct ggml_backend_sched_split splits[GGML_MAX_SPLITS];
446
+ int n_splits;
447
+
448
+ struct ggml_context * ctx;
449
+
450
+ // align context_buffer to GGML_MEM_ALIGN
451
+ #ifdef _MSC_VER
452
+ __declspec(align(GGML_MEM_ALIGN))
453
+ #else
454
+ __attribute__((aligned(GGML_MEM_ALIGN)))
455
+ #endif
456
+ char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + GGML_MAX_SPLITS*sizeof(struct ggml_cgraph)];
457
+ };
458
+
459
+ #define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node)
460
+ #define node_allocr(node) sched->node_talloc[hash_id(node)]
461
+
462
+ static bool ggml_is_view_op(enum ggml_op op) {
463
+ return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
464
+ }
465
+
466
+ // returns the priority of the backend, lower is better
467
+ static int sched_backend_prio(ggml_backend_sched_t sched, ggml_backend_t backend) {
468
+ for (int i = 0; i < sched->n_backends; i++) {
469
+ if (sched->backends[i] == backend) {
470
+ return i;
471
+ }
472
+ }
473
+ return INT_MAX;
474
+ }
475
+
476
+ static int sched_allocr_prio(ggml_backend_sched_t sched, ggml_tallocr_t allocr) {
477
+ for (int i = 0; i < sched->n_backends; i++) {
478
+ if (sched->tallocs[i] == allocr) {
479
+ return i;
480
+ }
481
+ }
482
+ return INT_MAX;
483
+ }
484
+
485
+ // returns the backend that should be used for the node based on the current locations
486
+ char causes[GGML_DEFAULT_GRAPH_SIZE*4 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
487
+ static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * node) {
488
+ // if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there
489
+ // ie. kv cache updates
490
+ // note that this doesn't allow fallback to CPU. need to add output tensors to the splits to copy the data back to the original backend.
491
+ // dst
492
+ ggml_backend_t cur_backend = ggml_get_backend(node);
493
+ if (cur_backend != NULL) {
494
+ sprintf(causes[hash_id(node)], "1.dst");
495
+ return cur_backend;
496
+ }
497
+
498
+ // view_src
499
+ if (node->view_src != NULL && ggml_get_backend(node->view_src) != NULL) {
500
+ sprintf(causes[hash_id(node)], "1.vsrc");
501
+ return ggml_get_backend(node->view_src);
502
+ }
503
+
504
+ // src
505
+ int cur_prio = INT_MAX;
506
+ size_t cur_size = 0;
507
+
508
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
509
+ const struct ggml_tensor * src = node->src[i];
510
+ if (src == NULL) {
511
+ break;
512
+ }
513
+ ggml_backend_t src_backend = ggml_get_backend(src);
514
+ if (src_backend != NULL) {
515
+ int src_prio = sched_backend_prio(sched, src_backend);
516
+ size_t src_size = ggml_nbytes(src);
517
+ if (src_prio < cur_prio && src_size >= cur_size) {
518
+ cur_prio = src_prio;
519
+ cur_size = src_size;
520
+ cur_backend = src_backend;
521
+ sprintf(causes[hash_id(node)], "1.src%d", i);
522
+ }
523
+ }
524
+ }
525
+ return cur_backend;
526
+ }
527
+
528
+ static char * fmt_size(size_t size) {
529
+ static char buffer[128];
530
+ if (size >= 1024*1024) {
531
+ sprintf(buffer, "%zuM", size/1024/1024);
532
+ } else {
533
+ sprintf(buffer, "%zuK", size/1024);
534
+ }
535
+ return buffer;
536
+ }
537
+
538
+ static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
539
+ int cur_split = 0;
540
+ for (int i = 0; i < graph->n_nodes; i++) {
541
+ if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
542
+ ggml_backend_t split_backend = ggml_tallocr_get_buffer(sched->splits[cur_split].tallocr)->backend;
543
+ fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend), sched->splits[cur_split].n_inputs);
544
+ for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
545
+ fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
546
+ }
547
+ fprintf(stderr, "\n");
548
+ cur_split++;
549
+ }
550
+ struct ggml_tensor * node = graph->nodes[i];
551
+ if (ggml_is_view_op(node->op)) {
552
+ continue;
553
+ }
554
+ ggml_tallocr_t node_allocr = node_allocr(node);
555
+ ggml_backend_t node_backend = node_allocr ? ggml_tallocr_get_buffer(node_allocr)->backend : NULL;
556
+ fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, ggml_op_name(node->op), node->name, fmt_size(ggml_nbytes(node)), node_allocr ? ggml_backend_name(node_backend) : "NULL", causes[hash_id(node)]);
557
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
558
+ struct ggml_tensor * src = node->src[j];
559
+ if (src == NULL) {
560
+ break;
561
+ }
562
+ ggml_tallocr_t src_allocr = node_allocr(src);
563
+ ggml_backend_t src_backend = src_allocr ? ggml_tallocr_get_buffer(src_allocr)->backend : NULL;
564
+ fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name, fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", causes[hash_id(src)]);
565
+ }
566
+ fprintf(stderr, "\n");
567
+ }
568
+ }
569
+
570
+ // creates a copy of the tensor with the same memory layout
571
+ static struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, const struct ggml_tensor * tensor) {
572
+ struct ggml_tensor * dup = ggml_dup_tensor(ctx, tensor);
573
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
574
+ dup->nb[i] = tensor->nb[i];
575
+ }
576
+ return dup;
577
+ }
578
+
579
+ // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
580
+ // TODO: merge passes
581
+ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
582
+ // reset state
583
+ size_t hash_size = sched->hash_set.size;
584
+ memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size);
585
+ memset(sched->node_talloc, 0, sizeof(sched->node_talloc[0]) * hash_size);
586
+ memset(sched->node_copies, 0, sizeof(sched->node_copies[0]) * hash_size);
587
+ sched->n_splits = 0;
588
+
589
+ struct ggml_init_params params = {
590
+ /*.mem_size = */ sizeof(sched->context_buffer),
591
+ /*.mem_buffer = */ sched->context_buffer,
592
+ /*.no_alloc = */ true
593
+ };
594
+
595
+ if (sched->ctx != NULL) {
596
+ ggml_free(sched->ctx);
597
+ }
598
+
599
+ sched->ctx = ggml_init(params);
600
+
601
+ // pass 1: assign backends to ops with allocated inputs
602
+ for (int i = 0; i < graph->n_leafs; i++) {
603
+ struct ggml_tensor * leaf = graph->leafs[i];
604
+ if (node_allocr(leaf) != NULL) {
605
+ // do not overwrite user assignments
606
+ continue;
607
+ }
608
+ ggml_backend_t leaf_backend = ggml_get_backend(leaf);
609
+ if (leaf_backend == NULL && leaf->view_src != NULL) {
610
+ leaf_backend = ggml_get_backend(leaf->view_src);
611
+ }
612
+ if (leaf_backend != NULL) {
613
+ node_allocr(leaf) = ggml_backend_sched_get_tallocr(sched, leaf_backend);
614
+ }
615
+ }
616
+
617
+ for (int i = 0; i < graph->n_nodes; i++) {
618
+ struct ggml_tensor * node = graph->nodes[i];
619
+ if (node_allocr(node) != NULL) {
620
+ // do not overwrite user assignments
621
+ continue;
622
+ }
623
+ ggml_backend_t node_backend = sched_backend_from_cur(sched, node);
624
+ if (node_backend != NULL) {
625
+ node_allocr(node) = ggml_backend_sched_get_tallocr(sched, node_backend);
626
+ }
627
+ }
628
+ //printf("PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
629
+
630
+ // pass 2: assign backends to ops from current assignments
631
+ // TODO:
632
+ // - reuse sched_backend_from_cur
633
+ for (int i = 0; i < graph->n_nodes; i++) {
634
+ struct ggml_tensor * node = graph->nodes[i];
635
+ ggml_tallocr_t node_allocr = node_allocr(node);
636
+ if (node_allocr == NULL) {
637
+ int cur_prio = INT_MAX;
638
+ size_t cur_size = 0;
639
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
640
+ struct ggml_tensor * src = node->src[j];
641
+ if (src == NULL) {
642
+ break;
643
+ }
644
+ ggml_tallocr_t src_allocr = node_allocr(src);
645
+ if (src_allocr != NULL) {
646
+ int src_prio = sched_allocr_prio(sched, src_allocr);
647
+ size_t src_size = ggml_nbytes(src);
648
+ if (src_prio < cur_prio && src_size >= cur_size) {
649
+ cur_prio = src_prio;
650
+ cur_size = src_size;
651
+ node_allocr = src_allocr;
652
+ sprintf(causes[hash_id(node)], "2.src%d", j);
653
+ }
654
+ }
655
+ }
656
+ if (node_allocr != NULL) {
657
+ node_allocr(node) = node_allocr;
658
+ }
659
+ }
660
+ }
661
+ //printf("PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
662
+
663
+ // pass 3: assign backends to remaining src from dst (should only be leafs)
664
+ for (int i = 0; i < graph->n_nodes; i++) {
665
+ struct ggml_tensor * node = graph->nodes[i];
666
+ ggml_tallocr_t node_allocr = node_allocr(node);
667
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
668
+ struct ggml_tensor * src = node->src[j];
669
+ if (src == NULL) {
670
+ break;
671
+ }
672
+ ggml_tallocr_t src_allocr = node_allocr(src);
673
+ if (src_allocr == NULL) {
674
+ node_allocr(src) = node_allocr;
675
+ }
676
+ }
677
+ }
678
+ //printf("PASS 3 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
679
+
680
+ // pass 4: split graph, find tensors that need to be copied
681
+ // TODO:
682
+ // - when switching from a less preferred backend to a more preferred backend, check if it is possible to move the switch to an earlier point for the same cost
683
+ // find first backend
684
+ int cur_split = 0;
685
+ for (int i = 0; i < graph->n_nodes; i++) {
686
+ struct ggml_tensor * node = graph->nodes[i];
687
+ if (node->view_src == NULL) {
688
+ sched->splits[0].tallocr = node_allocr(node);
689
+ break;
690
+ }
691
+ }
692
+ sched->splits[0].i_start = 0;
693
+ sched->splits[0].n_inputs = 0;
694
+ memset(sched->splits[0].inputs, 0, sizeof(sched->splits[0].inputs)); //HACK
695
+ ggml_tallocr_t cur_allocr = sched->splits[0].tallocr;
696
+ size_t cur_backend_id = sched_allocr_prio(sched, cur_allocr);
697
+ for (int i = 0; i < graph->n_nodes; i++) {
698
+ struct ggml_tensor * node = graph->nodes[i];
699
+
700
+ if (ggml_is_view_op(node->op)) {
701
+ continue;
702
+ }
703
+
704
+ ggml_tallocr_t node_allocr = node_allocr(node);
705
+
706
+ if (node_allocr != cur_allocr) {
707
+ sched->splits[cur_split].i_end = i;
708
+ cur_split++;
709
+ GGML_ASSERT(cur_split < GGML_MAX_SPLITS);
710
+ sched->splits[cur_split].tallocr = node_allocr;
711
+ sched->splits[cur_split].i_start = i;
712
+ sched->splits[cur_split].n_inputs = 0;
713
+ memset(sched->splits[cur_split].inputs, 0, sizeof(sched->splits[cur_split].inputs)); //HACK
714
+ cur_allocr = node_allocr;
715
+ cur_backend_id = sched_allocr_prio(sched, cur_allocr);
716
+ }
717
+
718
+ // find inputs that are not on the same backend
719
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
720
+ struct ggml_tensor * src = node->src[j];
721
+ if (src == NULL) {
722
+ break;
723
+ }
724
+ ggml_tallocr_t src_allocr = node_allocr(src);
725
+ if (src_allocr != node_allocr) {
726
+ int n_inputs = sched->splits[cur_split].n_inputs++;
727
+ GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS);
728
+ sched->splits[cur_split].inputs[n_inputs] = (struct ggml_tensor *)src;
729
+
730
+ // create copies
731
+ size_t id = hash_id(src);
732
+ if (sched->node_copies[id][cur_backend_id] == NULL) {
733
+ struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
734
+ sched->node_copies[id][cur_backend_id] = tensor_copy;
735
+ node_allocr(tensor_copy) = cur_allocr;
736
+ ggml_backend_t backend = ggml_tallocr_get_buffer(cur_allocr)->backend;
737
+ ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name);
738
+ }
739
+ node->src[j] = sched->node_copies[id][cur_backend_id];
740
+ }
741
+ }
742
+ }
743
+ sched->splits[cur_split].i_end = graph->n_nodes;
744
+ sched->n_splits = cur_split + 1;
745
+
746
+ //fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); fflush(stdout);
747
+
748
+ #if 1
749
+ // sanity check: all sources should have the same backend as the node
750
+ for (int i = 0; i < graph->n_nodes; i++) {
751
+ struct ggml_tensor * node = graph->nodes[i];
752
+ ggml_tallocr_t node_allocr = node_allocr(node);
753
+ if (node_allocr == NULL) {
754
+ fprintf(stderr, "!!!!!!! %s has no backend\n", node->name);
755
+ }
756
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
757
+ struct ggml_tensor * src = node->src[j];
758
+ if (src == NULL) {
759
+ break;
760
+ }
761
+ ggml_tallocr_t src_allocr = node_allocr(src);
762
+ if (src_allocr != node_allocr /* && src_backend != NULL */) { // ignore nulls for now
763
+ fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n",
764
+ node->name, node_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(node_allocr)->backend) : "NULL",
765
+ j, src->name, src_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(src_allocr)->backend) : "NULL");
766
+ }
767
+ }
768
+ }
769
+ #endif
770
+
771
+ // create copies of the graph for each split
772
+ // FIXME: avoid this copy, pass split inputs to ggml_gallocr_alloc_graph_n in some other way
773
+ struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_MAX_SPLIT_INPUTS, false);
774
+ for (int i = 0; i < sched->n_splits; i++) {
775
+ struct ggml_backend_sched_split * split = &sched->splits[i];
776
+ split->graph = ggml_graph_view(sched->ctx, graph, split->i_start, split->i_end);
777
+
778
+ // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
779
+ for (int j = 0; j < split->n_inputs; j++) {
780
+ struct ggml_tensor * input = split->inputs[j];
781
+ struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_allocr_prio(sched, split->tallocr)];
782
+ input_cpy->src[0] = input;
783
+ graph_copy->nodes[graph_copy->n_nodes++] = input_cpy;
784
+ }
785
+
786
+ for (int j = split->i_start; j < split->i_end; j++) {
787
+ graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j];
788
+ }
789
+ }
790
+ sched->graph = graph_copy;
791
+ }
792
+
793
+ static void sched_alloc_splits(ggml_backend_sched_t sched) {
794
+ ggml_gallocr_alloc_graph_n(
795
+ sched->galloc,
796
+ sched->graph,
797
+ sched->hash_set,
798
+ sched->node_talloc);
799
+ }
800
+
801
+ static void sched_compute_splits(ggml_backend_sched_t sched) {
802
+ uint64_t copy_us[GGML_MAX_BACKENDS] = {0};
803
+ uint64_t compute_us[GGML_MAX_BACKENDS] = {0};
804
+
805
+ struct ggml_backend_sched_split * splits = sched->splits;
806
+
807
+ for (int i = 0; i < sched->n_splits; i++) {
808
+ struct ggml_backend_sched_split * split = &splits[i];
809
+ ggml_backend_t split_backend = ggml_tallocr_get_buffer(split->tallocr)->backend;
810
+ int split_backend_id = sched_backend_prio(sched, split_backend);
811
+
812
+ // copy the input tensors to the split backend
813
+ uint64_t copy_start_us = ggml_time_us();
814
+ for (int j = 0; j < split->n_inputs; j++) {
815
+ struct ggml_tensor * input_cpy = sched->node_copies[hash_id(split->inputs[j])][sched_backend_prio(sched, split_backend)];
816
+ if (split->inputs[j]->buffer == NULL) {
817
+ if (split->inputs[j]->view_src == NULL) {
818
+ fprintf(stderr, "input %s has no buffer and no view_src\n", split->inputs[j]->name);
819
+ exit(1);
820
+ }
821
+ struct ggml_tensor * view = split->inputs[j];
822
+ view->backend = view->view_src->backend;
823
+ view->buffer = view->view_src->buffer;
824
+ view->data = (char *)view->view_src->data + view->view_offs;
825
+ ggml_backend_buffer_init_tensor(ggml_backend_sched_get_buffer(sched, view->buffer->backend), view);
826
+ }
827
+ if (input_cpy->buffer == NULL) {
828
+ fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name);
829
+ exit(1);
830
+ }
831
+ GGML_ASSERT(split->inputs[j]->buffer->backend != input_cpy->buffer->backend);
832
+ GGML_ASSERT(input_cpy->buffer->backend == split_backend);
833
+ ggml_backend_tensor_copy(split->inputs[j], input_cpy);
834
+ }
835
+ // ggml_backend_synchronize(split_backend);
836
+ int64_t copy_end_us = ggml_time_us();
837
+ copy_us[split_backend_id] += copy_end_us - copy_start_us;
838
+
839
+ #if 0
840
+ char split_filename[GGML_MAX_NAME];
841
+ snprintf(split_filename, GGML_MAX_NAME, "split_%i_%s.dot", i, ggml_backend_name(split_backend));
842
+ ggml_graph_dump_dot(split->graph, NULL, split_filename);
843
+ #endif
844
+
845
+ uint64_t compute_start_us = ggml_time_us();
846
+ ggml_backend_graph_compute(split_backend, split->graph);
847
+ // ggml_backend_synchronize(split_backend);
848
+ uint64_t compute_end_us = ggml_time_us();
849
+ compute_us[split_backend_id] += compute_end_us - compute_start_us;
850
+ }
851
+
852
+ #if 0
853
+ // per-backend timings
854
+ fprintf(stderr, "sched_compute_splits times (%d splits):\n", sched->n_splits);
855
+ for (int i = 0; i < sched->n_backends; i++) {
856
+ if (copy_us[i] > 0 || compute_us[i] > 0) {
857
+ fprintf(stderr, "\t%5.5s: %lu us copy, %lu us compute\n", ggml_backend_name(sched->backends[i]), copy_us[i], compute_us[i]);
858
+ }
859
+ }
860
+ #endif
861
+ }
862
+
863
+ static void sched_reset(ggml_backend_sched_t sched) {
864
+ for (int i = 0; i < sched->n_backends; i++) {
865
+ ggml_tallocr_reset(sched->tallocs[i]);
866
+ }
867
+ }
868
+
869
+ ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends) {
870
+ GGML_ASSERT(n_backends <= GGML_MAX_BACKENDS);
871
+
872
+ struct ggml_backend_sched * sched = malloc(sizeof(struct ggml_backend_sched));
873
+ memset(sched, 0, sizeof(struct ggml_backend_sched));
874
+
875
+ fprintf(stderr, "ggml_backend_sched size: %lu KB\n", sizeof(struct ggml_backend_sched)/1024);
876
+
877
+ sched->n_backends = n_backends;
878
+ for (int i = 0; i < n_backends; i++) {
879
+ sched->backends[i] = backends[i];
880
+ }
881
+
882
+ sched->galloc = ggml_gallocr_new();
883
+
884
+ // init measure allocs for each backend
885
+ for (int i = 0; i < n_backends; i++) {
886
+ sched->tallocs[i] = ggml_tallocr_new_measure_from_backend(backends[i]);
887
+ }
888
+
889
+ return sched;
890
+ }
891
+
892
+ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
893
+ if (sched == NULL) {
894
+ return;
895
+ }
896
+ for (int i = 0; i < sched->n_backends; i++) {
897
+ ggml_tallocr_free(sched->tallocs[i]);
898
+ }
899
+ ggml_gallocr_free(sched->galloc);
900
+ free(sched->hash_set.keys);
901
+ free(sched->node_talloc);
902
+ free(sched->node_copies);
903
+ free(sched);
904
+ }
905
+
906
+ void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
907
+ // initialize hash tables
908
+ size_t hash_size = measure_graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS;
909
+ sched->hash_set.size = hash_size;
910
+ sched->hash_set.keys = malloc(sizeof(sched->hash_set.keys[0]) * hash_size);
911
+ sched->node_talloc = malloc(sizeof(sched->node_talloc[0]) * hash_size);
912
+ sched->node_copies = malloc(sizeof(sched->node_copies[0]) * hash_size);
913
+
914
+ sched_split_graph(sched, measure_graph);
915
+ sched_alloc_splits(sched);
916
+
917
+ // allocate buffers and reset allocators
918
+ for (int i = 0; i < sched->n_backends; i++) {
919
+ size_t size = ggml_tallocr_max_size(sched->tallocs[i]);
920
+ ggml_tallocr_free(sched->tallocs[i]);
921
+ sched->tallocs[i] = ggml_tallocr_new_from_backend(sched->backends[i], size);
922
+ }
923
+
924
+ sched_reset(sched);
925
+ }
926
+
927
+ void ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
928
+ GGML_ASSERT(sched->hash_set.size >= graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS);
929
+
930
+ sched_split_graph(sched, graph);
931
+ sched_alloc_splits(sched);
932
+ sched_compute_splits(sched);
933
+ sched_reset(sched);
934
+ }
935
+
936
+ ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend) {
937
+ int backend_index = sched_backend_prio(sched, backend);
938
+ return sched->tallocs[backend_index];
939
+ }
940
+
941
+ ggml_backend_buffer_t ggml_backend_sched_get_buffer(ggml_backend_sched_t sched, ggml_backend_t backend) {
942
+ int backend_index = sched_backend_prio(sched, backend);
943
+ return ggml_tallocr_get_buffer(sched->tallocs[backend_index]);
944
+ }
945
+
946
+ void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
947
+ int backend_index = sched_backend_prio(sched, backend);
948
+ GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
949
+ node_allocr(node) = sched->tallocs[backend_index];
950
+ }
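
As a rough usage sketch (not part of the commit), the CPU backend implemented above can be exercised on its own; error handling is omitted:

    #include <stdio.h>
    #include "ggml-backend.h"

    int main(void) {
        ggml_backend_t cpu = ggml_backend_cpu_init();
        ggml_backend_cpu_set_n_threads(cpu, 4);

        // allocate a 1 MiB buffer managed by the backend
        ggml_backend_buffer_t buf = ggml_backend_alloc_buffer(cpu, 1024*1024);

        printf("backend: %s, buffer size: %zu, alignment: %zu\n",
               ggml_backend_name(cpu),
               ggml_backend_buffer_get_size(buf),
               ggml_backend_get_alignment(cpu));

        ggml_backend_buffer_free(buf);
        ggml_backend_free(cpu);
        return 0;
    }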
bindings/ruby/ext/ggml-backend.h ADDED
@@ -0,0 +1,136 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-alloc.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+    //
+    // Backend buffer
+    //
+
+    struct ggml_backend_buffer;
+    typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
+
+    // backend buffer functions
+    GGML_API void   ggml_backend_buffer_free          (ggml_backend_buffer_t buffer);
+    GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
+    GGML_API void * ggml_backend_buffer_get_base      (ggml_backend_buffer_t buffer);
+    GGML_API size_t ggml_backend_buffer_get_size      (ggml_backend_buffer_t buffer);
+    GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API void   ggml_backend_buffer_init_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API void   ggml_backend_buffer_free_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+
+    //
+    // Backend
+    //
+
+    struct ggml_backend;
+    typedef struct ggml_backend * ggml_backend_t;
+    typedef void * ggml_backend_graph_plan_t;
+
+    GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor);
+
+    GGML_API const char * ggml_backend_name(ggml_backend_t backend);
+    GGML_API void         ggml_backend_free(ggml_backend_t backend);
+
+    GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
+
+    GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
+
+    GGML_API void ggml_backend_tensor_set_async(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+
+    GGML_API void ggml_backend_tensor_set(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+
+    GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
+
+    GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+
+    GGML_API void ggml_backend_graph_plan_free   (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+    GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+    GGML_API void ggml_backend_graph_compute     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+    GGML_API bool ggml_backend_supports_op       (ggml_backend_t backend, const struct ggml_tensor * op);
+
+    // tensor copy between different backends
+    GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
+
+    //
+    // CPU backend
+    //
+
+    GGML_API ggml_backend_t ggml_backend_cpu_init(void);
+
+    GGML_API bool ggml_backend_is_cpu(ggml_backend_t backend);
+    GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
+
+    // Create a backend buffer from an existing pointer
+    GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size);
+
+
+    //
+    // Backend scheduler
+    //
+
+    // The backend scheduler allows for multiple backends to be used together
+    // Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends
+    // The backends are selected based on:
+    // - the backend that supports the operation
+    // - the location of the pre-allocated tensors (e.g. the weights)
+    /*
+      Example usage:
+
+        sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, num_backends);
+        // sched is initialized with measure allocators and cannot be used until allocated with a measure graph
+
+        // initialize buffers from a measure graph
+        measure_graph = build_graph(sched); // use the allocr to allocate inputs as needed
+
+        // in build_graph:
+        build_graph(...) {
+            // allocating tensors in a specific backend (optional, recommended: pre-allocate inputs in a different buffer)
+            alloc_cpu = ggml_backend_sched_get_allocr(sched, backend_cpu);
+            ggml_allocr_alloc(alloc_cpu, tensor);
+
+            // manually assigning nodes to a backend (optional, shouldn't be needed in most cases)
+            struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
+            ggml_backend_sched_set_node_backend(sched, node, backend_gpu);
+        }
+
+        // allocate backend buffers from measure graph
+        ggml_backend_sched_init_measure(sched, measure_graph);
+
+        // the scheduler is now ready to compute graphs
+
+        // compute
+        graph = build_graph(sched);
+        ggml_backend_sched_graph_compute(sched, graph);
+    */
+
+    struct ggml_backend_sched;
+    typedef struct ggml_backend_sched * ggml_backend_sched_t;
+
+    // Initialize a backend scheduler
+    GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends);
+
+    GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
+
+    // Initialize backend buffers from a measure graph
+    GGML_API void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
+
+    GGML_API ggml_tallocr_t        ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend);
+    GGML_API ggml_backend_buffer_t ggml_backend_sched_get_buffer (ggml_backend_sched_t sched, ggml_backend_t backend);
+
+    GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
+
+    // Allocate a graph on the backend scheduler
+    GGML_API void ggml_backend_sched_graph_compute(
+            ggml_backend_sched_t sched,
+            struct ggml_cgraph * graph);
+
+#ifdef  __cplusplus
+}
+#endif
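
One more hedged sketch (not from the commit), for ggml_backend_cpu_buffer_from_ptr declared above: it wraps caller-owned memory, so the buffer does not take ownership and freeing the buffer does not free the pointer.

    #include <stdlib.h>
    #include "ggml-backend.h"

    int main(void) {
        ggml_backend_t cpu = ggml_backend_cpu_init();

        const size_t size = 16u*1024u*1024u;
        void * data = malloc(size); // caller-owned storage

        ggml_backend_buffer_t buf = ggml_backend_cpu_buffer_from_ptr(cpu, data, size);

        // ... place tensors inside `buf` (e.g. via ggml-alloc) and compute ...

        ggml_backend_buffer_free(buf); // does not free `data`
        free(data);
        ggml_backend_free(cpu);
        return 0;
    }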
bindings/ruby/ext/ggml-impl.h ADDED
@@ -0,0 +1,249 @@
1
+ #pragma once
2
+
3
+ #include "ggml.h"
4
+
5
+ // GGML internal header
6
+
7
+ #include <assert.h>
8
+ #include <stddef.h>
9
+ #include <stdbool.h>
10
+ #include <string.h> // memcpy
11
+ #include <math.h> // fabsf
12
+
13
+ #ifdef __cplusplus
14
+ extern "C" {
15
+ #endif
16
+
17
+ // static_assert should be a #define, but if it's not,
18
+ // fall back to the _Static_assert C11 keyword.
19
+ // if C99 - static_assert is noop
20
+ // ref: https://stackoverflow.com/a/53923785/4039976
21
+ #ifndef static_assert
22
+ #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
23
+ #define static_assert(cond, msg) _Static_assert(cond, msg)
24
+ #else
25
+ #define static_assert(cond, msg) struct global_scope_noop_trick
26
+ #endif
27
+ #endif
28
+
29
+ // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
30
+ #if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
31
+ #ifndef __FMA__
32
+ #define __FMA__
33
+ #endif
34
+ #ifndef __F16C__
35
+ #define __F16C__
36
+ #endif
37
+ #ifndef __SSE3__
38
+ #define __SSE3__
39
+ #endif
40
+ #endif
41
+
42
+ #undef MIN
43
+ #undef MAX
44
+
45
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
46
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
47
+
48
+ // 16-bit float
49
+ // on Arm, we use __fp16
50
+ // on x86, we use uint16_t
51
+ #if defined(__ARM_NEON) && !defined(_MSC_VER)
52
+
53
+ // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
54
+ //
55
+ // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
56
+ //
57
+ #include <arm_neon.h>
58
+
59
+ #define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
60
+ #define GGML_COMPUTE_FP32_TO_FP16(x) (x)
61
+
62
+ #define GGML_FP16_TO_FP32(x) ((float) (x))
63
+ #define GGML_FP32_TO_FP16(x) (x)
64
+
65
+ #else
66
+
67
+ #ifdef __wasm_simd128__
68
+ #include <wasm_simd128.h>
69
+ #else
70
+ #ifdef __POWER9_VECTOR__
71
+ #include <altivec.h>
72
+ #undef bool
73
+ #define bool _Bool
74
+ #else
75
+ #if defined(_MSC_VER) || defined(__MINGW32__)
76
+ #include <intrin.h>
77
+ #else
78
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
79
+ #if !defined(__riscv)
80
+ #include <immintrin.h>
81
+ #endif
82
+ #endif
83
+ #endif
84
+ #endif
85
+ #endif
86
+
87
+ #ifdef __riscv_v_intrinsic
88
+ #include <riscv_vector.h>
89
+ #endif
90
+
91
+ #ifdef __F16C__
92
+
93
+ #ifdef _MSC_VER
94
+ #define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
95
+ #define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
96
+ #else
97
+ #define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
98
+ #define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
99
+ #endif
100
+
101
+ #elif defined(__POWER9_VECTOR__)
102
+
103
+ #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
104
+ #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
105
+ /* the inline asm below is about 12% faster than the lookup method */
106
+ #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
107
+ #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
108
+
109
+ static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
110
+ register float f;
111
+ register double d;
112
+ __asm__(
113
+ "mtfprd %0,%2\n"
114
+ "xscvhpdp %0,%0\n"
115
+ "frsp %1,%0\n" :
116
+ /* temp */ "=d"(d),
117
+ /* out */ "=f"(f):
118
+ /* in */ "r"(h));
119
+ return f;
120
+ }
121
+
122
+ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
123
+ register double d;
124
+ register ggml_fp16_t r;
125
+ __asm__( /* xscvdphp can work on double or single precision */
126
+ "xscvdphp %0,%2\n"
127
+ "mffprd %1,%0\n" :
128
+ /* temp */ "=d"(d),
129
+ /* out */ "=r"(r):
130
+ /* in */ "f"(f));
131
+ return r;
132
+ }
133
+
134
+ #else
135
+
136
+ // FP16 <-> FP32
137
+ // ref: https://github.com/Maratyszcza/FP16
138
+
139
+ static inline float fp32_from_bits(uint32_t w) {
140
+ union {
141
+ uint32_t as_bits;
142
+ float as_value;
143
+ } fp32;
144
+ fp32.as_bits = w;
145
+ return fp32.as_value;
146
+ }
147
+
148
+ static inline uint32_t fp32_to_bits(float f) {
149
+ union {
150
+ float as_value;
151
+ uint32_t as_bits;
152
+ } fp32;
153
+ fp32.as_value = f;
154
+ return fp32.as_bits;
155
+ }
156
+
157
+ static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
158
+ const uint32_t w = (uint32_t) h << 16;
159
+ const uint32_t sign = w & UINT32_C(0x80000000);
160
+ const uint32_t two_w = w + w;
161
+
162
+ const uint32_t exp_offset = UINT32_C(0xE0) << 23;
163
+ #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
164
+ const float exp_scale = 0x1.0p-112f;
165
+ #else
166
+ const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
167
+ #endif
168
+ const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
169
+
170
+ const uint32_t magic_mask = UINT32_C(126) << 23;
171
+ const float magic_bias = 0.5f;
172
+ const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
173
+
174
+ const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
175
+ const uint32_t result = sign |
176
+ (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
177
+ return fp32_from_bits(result);
178
+ }
179
+
180
+ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
181
+ #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
182
+ const float scale_to_inf = 0x1.0p+112f;
183
+ const float scale_to_zero = 0x1.0p-110f;
184
+ #else
185
+ const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
186
+ const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
187
+ #endif
188
+ float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
189
+
190
+ const uint32_t w = fp32_to_bits(f);
191
+ const uint32_t shl1_w = w + w;
192
+ const uint32_t sign = w & UINT32_C(0x80000000);
193
+ uint32_t bias = shl1_w & UINT32_C(0xFF000000);
194
+ if (bias < UINT32_C(0x71000000)) {
195
+ bias = UINT32_C(0x71000000);
196
+ }
197
+
198
+ base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
199
+ const uint32_t bits = fp32_to_bits(base);
200
+ const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
201
+ const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
202
+ const uint32_t nonsign = exp_bits + mantissa_bits;
203
+ return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
204
+ }
205
+
206
+ #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
207
+ #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
208
+
209
+ #endif // __F16C__
210
+
211
+ #endif // __ARM_NEON
212
+
213
+ // precomputed f32 table for f16 (256 KB)
214
+ // defined in ggml.c, initialized in ggml_init()
215
+ extern float ggml_table_f32_f16[1 << 16];
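// [editor's note, not part of ggml-impl.h] size check for the table above:
// (1 << 16) entries * sizeof(float) = 65536 * 4 bytes = 262144 bytes = 256 KB, which is where the "256 KB" figure comes from.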
216
+
217
+ // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
218
+ // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
219
+ // This is also true for POWER9.
220
+ #if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
221
+
222
+ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
223
+ uint16_t s;
224
+ memcpy(&s, &f, sizeof(uint16_t));
225
+ return ggml_table_f32_f16[s];
226
+ }
227
+
228
+ #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
229
+ #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
230
+
231
+ #endif
232
+
233
+ #define GGML_HASHTABLE_FULL ((size_t)-1)
234
+ #define GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2)
235
+
236
+ bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
237
+
238
+ // returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
239
+ size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
240
+
241
+ // returns GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
242
+ size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key);
243
+
244
+ // return index, asserts if table is full
245
+ size_t ggml_hash_find_or_insert( struct ggml_hash_set hash_set, struct ggml_tensor * key);
246
+
247
+ #ifdef __cplusplus
248
+ }
249
+ #endif
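The ggml-impl.h header added above selects an FP16 conversion path per platform (NEON, F16C, POWER9 inline asm, or the portable bit-twiddling fallback), but the GGML_COMPUTE_FP16_TO_FP32 / GGML_COMPUTE_FP32_TO_FP16 macros mean the same thing on every path. A minimal round-trip sketch, assuming ggml-impl.h and the rest of ggml are available to the build; the table-based GGML_FP16_TO_FP32 variant additionally needs ggml_init() to have filled ggml_table_f32_f16, which is why the sketch sticks to the COMPUTE macros:

    #include <stdio.h>

    #include "ggml-impl.h" // internal header shown above; pulls in ggml.h for ggml_fp16_t

    int main(void) {
        // round-trip a float through half precision; on x86 without F16C this
        // exercises the portable fp32_to_bits()/fp32_from_bits() converters above
        const float x = 3.14159f;
        const ggml_fp16_t h = GGML_COMPUTE_FP32_TO_FP16(x);
        const float y = GGML_COMPUTE_FP16_TO_FP32(h);
        printf("%.6f -> %.6f after the fp16 round trip\n", x, y);
        return 0;
    }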
bindings/ruby/ext/ggml-quants.c ADDED
The diff for this file is too large to render. See raw diff
 
bindings/ruby/ext/ggml-quants.h ADDED
@@ -0,0 +1,224 @@
1
+ #pragma once
2
+
3
+ #include "ggml-impl.h"
4
+
5
+ // GGML internal header
6
+
7
+ #include <stdint.h>
8
+ #include <stddef.h>
9
+
10
+ #define QK4_0 32
11
+ typedef struct {
12
+ ggml_fp16_t d; // delta
13
+ uint8_t qs[QK4_0 / 2]; // nibbles / quants
14
+ } block_q4_0;
15
+ static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
16
+
17
+ #define QK4_1 32
18
+ typedef struct {
19
+ ggml_fp16_t d; // delta
20
+ ggml_fp16_t m; // min
21
+ uint8_t qs[QK4_1 / 2]; // nibbles / quants
22
+ } block_q4_1;
23
+ static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
24
+
25
+ #define QK5_0 32
26
+ typedef struct {
27
+ ggml_fp16_t d; // delta
28
+ uint8_t qh[4]; // 5-th bit of quants
29
+ uint8_t qs[QK5_0 / 2]; // nibbles / quants
30
+ } block_q5_0;
31
+ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
32
+
33
+ #define QK5_1 32
34
+ typedef struct {
35
+ ggml_fp16_t d; // delta
36
+ ggml_fp16_t m; // min
37
+ uint8_t qh[4]; // 5-th bit of quants
38
+ uint8_t qs[QK5_1 / 2]; // nibbles / quants
39
+ } block_q5_1;
40
+ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
41
+
42
+ #define QK8_0 32
43
+ typedef struct {
44
+ ggml_fp16_t d; // delta
45
+ int8_t qs[QK8_0]; // quants
46
+ } block_q8_0;
47
+ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
48
+
49
+ #define QK8_1 32
50
+ typedef struct {
51
+ float d; // delta
52
+ float s; // d * sum(qs[i])
53
+ int8_t qs[QK8_1]; // quants
54
+ } block_q8_1;
55
+ static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
56
+
57
+ //
58
+ // Super-block quantization structures
59
+ //
60
+
61
+ // Super-block size
62
+ #ifdef GGML_QKK_64
63
+ #define QK_K 64
64
+ #define K_SCALE_SIZE 4
65
+ #else
66
+ #define QK_K 256
67
+ #define K_SCALE_SIZE 12
68
+ #endif
69
+
70
+ // 2-bit quantization
71
+ // weight is represented as x = a * q + b
72
+ // 16 blocks of 16 elements each
73
+ // Effectively 2.5625 bits per weight
74
+ typedef struct {
75
+ uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
76
+ uint8_t qs[QK_K/4]; // quants
77
+ ggml_fp16_t d; // super-block scale for quantized scales
78
+ ggml_fp16_t dmin; // super-block scale for quantized mins
79
+ } block_q2_K;
80
+ static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
81
+
82
+ // 3-bit quantization
83
+ // weight is represented as x = a * q
84
+ // 16 blocks of 16 elements each
85
+ // Effectively 3.4375 bits per weight
86
+ #ifdef GGML_QKK_64
87
+ typedef struct {
88
+ uint8_t hmask[QK_K/8]; // quants - high bit
89
+ uint8_t qs[QK_K/4]; // quants - low 2 bits
90
+ uint8_t scales[2];
91
+ ggml_fp16_t d; // super-block scale
92
+ } block_q3_K;
93
+ static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
94
+ #else
95
+ typedef struct {
96
+ uint8_t hmask[QK_K/8]; // quants - high bit
97
+ uint8_t qs[QK_K/4]; // quants - low 2 bits
98
+ uint8_t scales[12]; // scales, quantized with 6 bits
99
+ ggml_fp16_t d; // super-block scale
100
+ } block_q3_K;
101
+ static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
102
+ #endif
103
+
104
+ // 4-bit quantization
105
+ // 8 blocks of 32 elements each
106
+ // weight is represented as x = a * q + b
107
+ // Effectively 4.5 bits per weight
108
+ #ifdef GGML_QKK_64
109
+ typedef struct {
110
+ ggml_fp16_t d[2]; // super-block scales/mins
111
+ uint8_t scales[2]; // 4-bit block scales/mins
112
+ uint8_t qs[QK_K/2]; // 4-bit quants
113
+ } block_q4_K;
114
+ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
115
+ #else
116
+ typedef struct {
117
+ ggml_fp16_t d; // super-block scale for quantized scales
118
+ ggml_fp16_t dmin; // super-block scale for quantized mins
119
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
120
+ uint8_t qs[QK_K/2]; // 4-bit quants
121
+ } block_q4_K;
122
+ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
123
+ #endif
124
+
125
+ // 5-bit quantization
126
+ // 8 blocks of 32 elements each
127
+ // weight is represented as x = a * q + b
128
+ // Effectively 5.5 bits per weight
129
+ #ifdef GGML_QKK_64
130
+ typedef struct {
131
+ ggml_fp16_t d; // super-block scale
132
+ int8_t scales[QK_K/16]; // 8-bit block scales
133
+ uint8_t qh[QK_K/8]; // quants, high bit
134
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
135
+ } block_q5_K;
136
+ static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
137
+ #else
138
+ typedef struct {
139
+ ggml_fp16_t d; // super-block scale for quantized scales
140
+ ggml_fp16_t dmin; // super-block scale for quantized mins
141
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
142
+ uint8_t qh[QK_K/8]; // quants, high bit
143
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
144
+ } block_q5_K;
145
+ static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
146
+ #endif
147
+
148
+ // 6-bit quantization
149
+ // weight is represented as x = a * q
150
+ // 16 blocks of 16 elements each
151
+ // Effectively 6.5625 bits per weight
152
+ typedef struct {
153
+ uint8_t ql[QK_K/2]; // quants, lower 4 bits
154
+ uint8_t qh[QK_K/4]; // quants, upper 2 bits
155
+ int8_t scales[QK_K/16]; // scales, quantized with 8 bits
156
+ ggml_fp16_t d; // super-block scale
157
+ } block_q6_K;
158
+ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
159
+
160
+ // This is only used for intermediate quantization and dot products
161
+ typedef struct {
162
+ float d; // delta
163
+ int8_t qs[QK_K]; // quants
164
+ int16_t bsums[QK_K/16]; // sum of quants in groups of 16
165
+ } block_q8_K;
166
+ static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
167
+
168
+
169
+ // Quantization
170
+ void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
171
+ void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k);
172
+ void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k);
173
+ void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k);
174
+ void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k);
175
+ void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k);
176
+
177
+ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k);
178
+ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k);
179
+ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k);
180
+ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
181
+ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
182
+ void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
183
+
184
+ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
185
+ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
186
+ void quantize_row_q5_0(const float * restrict x, void * restrict y, int k);
187
+ void quantize_row_q5_1(const float * restrict x, void * restrict y, int k);
188
+ void quantize_row_q8_0(const float * restrict x, void * restrict y, int k);
189
+ void quantize_row_q8_1(const float * restrict x, void * restrict y, int k);
190
+
191
+ void quantize_row_q2_K(const float * restrict x, void * restrict y, int k);
192
+ void quantize_row_q3_K(const float * restrict x, void * restrict y, int k);
193
+ void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
194
+ void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
195
+ void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
196
+ void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
197
+
198
+ // Dequantization
199
+ void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
200
+ void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k);
201
+ void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k);
202
+ void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k);
203
+ void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k);
204
+ //void dequantize_row_q8_1(const block_q8_1 * restrict x, float * restrict y, int k);
205
+
206
+ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k);
207
+ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k);
208
+ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k);
209
+ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k);
210
+ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
211
+ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
212
+
213
+ // Dot product
214
+ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
215
+ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
216
+ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
217
+ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
218
+ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
219
+
220
+ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
221
+ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
222
+ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
223
+ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
224
+ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
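Each block type declared above packs a fixed group of weights together with its scale(s), and the static_asserts pin the exact byte layout. For block_q4_0, for example, QK4_0 = 32 weights are stored as one fp16 delta plus 16 bytes of 4-bit nibbles, i.e. 18 bytes per block, or 18 * 8 / 32 = 4.5 bits per weight. A hedged sketch of a single-block quantize/dequantize round trip using the reference routines declared above (it assumes ggml-quants.c is compiled and linked alongside, and must be built as C because of the restrict qualifiers):

    #include <stdio.h>

    #include "ggml-quants.h" // the header shown above

    int main(void) {
        float src[QK4_0];
        float dst[QK4_0];
        for (int i = 0; i < QK4_0; ++i) {
            src[i] = 0.1f * (float) i - 1.5f; // arbitrary test data
        }

        block_q4_0 blk;
        quantize_row_q4_0_reference(src, &blk, QK4_0); // 32 floats -> one 18-byte block
        dequantize_row_q4_0(&blk, dst, QK4_0);         // reconstruct (lossy)

        printf("src[5] = %.3f, dst[5] = %.3f, sizeof(block_q4_0) = %zu bytes\n",
               src[5], dst[5], sizeof(block_q4_0));
        return 0;
    }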
examples/common.cpp CHANGED
@@ -38,12 +38,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
38
  params.seed = std::stoi(get_next_arg(i, argc, argv, arg, params));
39
  } else if (arg == "-t" || arg == "--threads") {
40
  params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params));
41
- } else if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") {
42
- params.n_gpu_layers = std::stoi(get_next_arg(i, argc, argv, arg, params));
43
  } else if (arg == "-p" || arg == "--prompt") {
44
  params.prompt = get_next_arg(i, argc, argv, arg, params);
45
  } else if (arg == "-n" || arg == "--n_predict") {
46
  params.n_predict = std::stoi(get_next_arg(i, argc, argv, arg, params));
 
 
47
  } else if (arg == "--top_k") {
48
  params.top_k = std::stoi(get_next_arg(i, argc, argv, arg, params));
49
  } else if (arg == "--top_p") {
@@ -56,6 +56,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
56
  params.repeat_penalty = std::stof(get_next_arg(i, argc, argv, arg, params));
57
  } else if (arg == "-b" || arg == "--batch_size") {
58
  params.n_batch= std::stoi(get_next_arg(i, argc, argv, arg, params));
 
 
 
 
 
 
59
  } else if (arg == "-m" || arg == "--model") {
60
  params.model = get_next_arg(i, argc, argv, arg, params);
61
  } else if (arg == "-i" || arg == "--interactive") {
@@ -97,7 +103,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
97
  fprintf(stderr, " -h, --help show this help message and exit\n");
98
  fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
99
  fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
100
- fprintf(stderr, " -ngl N, --gpu-layers N number of layers to offload to GPU on supported models (default: %d)\n", params.n_gpu_layers);
101
  fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
102
  fprintf(stderr, " prompt to start generation with (default: random)\n");
103
  fprintf(stderr, " -f FNAME, --file FNAME\n");
@@ -111,6 +116,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
111
  fprintf(stderr, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
112
  fprintf(stderr, " --repeat-penalty N penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty);
113
  fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
 
 
 
114
  fprintf(stderr, " -m FNAME, --model FNAME\n");
115
  fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
116
  fprintf(stderr, "\n");
 
38
  params.seed = std::stoi(get_next_arg(i, argc, argv, arg, params));
39
  } else if (arg == "-t" || arg == "--threads") {
40
  params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params));
 
 
41
  } else if (arg == "-p" || arg == "--prompt") {
42
  params.prompt = get_next_arg(i, argc, argv, arg, params);
43
  } else if (arg == "-n" || arg == "--n_predict") {
44
  params.n_predict = std::stoi(get_next_arg(i, argc, argv, arg, params));
45
+ } else if (arg == "-np" || arg == "--n_parallel") {
46
+ params.n_parallel = std::stoi(get_next_arg(i, argc, argv, arg, params));
47
  } else if (arg == "--top_k") {
48
  params.top_k = std::stoi(get_next_arg(i, argc, argv, arg, params));
49
  } else if (arg == "--top_p") {
 
56
  params.repeat_penalty = std::stof(get_next_arg(i, argc, argv, arg, params));
57
  } else if (arg == "-b" || arg == "--batch_size") {
58
  params.n_batch= std::stoi(get_next_arg(i, argc, argv, arg, params));
59
+ } else if (arg == "-c" || arg == "--context") {
60
+ params.n_ctx= std::stoi(get_next_arg(i, argc, argv, arg, params));
61
+ } else if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") {
62
+ params.n_gpu_layers = std::stoi(get_next_arg(i, argc, argv, arg, params));
63
+ } else if (arg == "--ignore-eos") {
64
+ params.ignore_eos = true;
65
  } else if (arg == "-m" || arg == "--model") {
66
  params.model = get_next_arg(i, argc, argv, arg, params);
67
  } else if (arg == "-i" || arg == "--interactive") {
 
103
  fprintf(stderr, " -h, --help show this help message and exit\n");
104
  fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
105
  fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
 
106
  fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
107
  fprintf(stderr, " prompt to start generation with (default: random)\n");
108
  fprintf(stderr, " -f FNAME, --file FNAME\n");
 
116
  fprintf(stderr, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
117
  fprintf(stderr, " --repeat-penalty N penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty);
118
  fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
119
+ fprintf(stderr, " -c N, --context N context / KV cache size (default: %d)\n", params.n_ctx);
120
+ fprintf(stderr, " --ignore-eos ignore EOS token during generation\n");
121
+ fprintf(stderr, " -ngl N, --gpu-layers N number of layers to offload to GPU on supported models (default: %d)\n", params.n_gpu_layers);
122
  fprintf(stderr, " -m FNAME, --model FNAME\n");
123
  fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
124
  fprintf(stderr, "\n");
examples/common.h CHANGED
@@ -17,10 +17,15 @@
17
  //
18
 
19
  struct gpt_params {
20
- int32_t seed = -1; // RNG seed
21
- int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
22
- int32_t n_predict = 200; // new tokens to predict
23
- int32_t n_batch = 8; // batch size for prompt processing
 
 
 
 
 
24
 
25
  // sampling parameters
26
  int32_t top_k = 40;
@@ -35,8 +40,6 @@ struct gpt_params {
35
 
36
  bool interactive = false;
37
  int32_t interactive_port = -1;
38
-
39
- int32_t n_gpu_layers = 0;
40
  };
41
 
42
  bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
 
17
  //
18
 
19
  struct gpt_params {
20
+ int32_t seed = -1; // RNG seed
21
+ int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
22
+ int32_t n_predict = 200; // new tokens to predict
23
+ int32_t n_parallel = 1; // number of parallel streams
24
+ int32_t n_batch = 8; // batch size for prompt processing
25
+ int32_t n_ctx = 2048; // context size (this is the KV cache max size)
26
+ int32_t n_gpu_layers = 0; // number of layers to offload to the GPU
27
+
28
+ bool ignore_eos = false; // ignore EOS token when generating text
29
 
30
  // sampling parameters
31
  int32_t top_k = 40;
 
40
 
41
  bool interactive = false;
42
  int32_t interactive_port = -1;
 
 
43
  };
44
 
45
  bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
examples/talk-llama/CMakeLists.txt CHANGED
@@ -7,7 +7,14 @@ if (WHISPER_SDL2)
7
 
8
  # TODO: this is temporary
9
  # need to export ggml symbols for MSVC, but too lazy ..
10
- add_executable(${TARGET} talk-llama.cpp llama.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../ggml-alloc.c ../../whisper.cpp)
 
 
 
 
 
 
 
11
 
12
  target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
13
  target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
 
7
 
8
  # TODO: this is temporary
9
  # need to export ggml symbols for MSVC, but too lazy ..
10
+ add_executable(${TARGET}
11
+ talk-llama.cpp
12
+ llama.cpp
13
+ ../common.cpp
14
+ ../common-sdl.cpp
15
+ ../../ggml.c
16
+ ../../ggml-alloc.c
17
+ ../../whisper.cpp)
18
 
19
  target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
20
  target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
examples/talk-llama/llama.cpp CHANGED
The diff for this file is too large to render. See raw diff
 
examples/talk-llama/llama.h CHANGED
@@ -37,10 +37,12 @@
37
 
38
  #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
39
 
 
 
40
  #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
41
 
42
  #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
43
- #define LLAMA_SESSION_VERSION 1
44
 
45
  #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
46
  // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
@@ -60,13 +62,9 @@ extern "C" {
60
  struct llama_model;
61
  struct llama_context;
62
 
63
- typedef int llama_token;
64
-
65
- enum llama_log_level {
66
- LLAMA_LOG_LEVEL_ERROR = 2,
67
- LLAMA_LOG_LEVEL_WARN = 3,
68
- LLAMA_LOG_LEVEL_INFO = 4
69
- };
70
 
71
  enum llama_vocab_type {
72
  LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
@@ -86,28 +84,36 @@ extern "C" {
86
  // model file types
87
  enum llama_ftype {
88
  LLAMA_FTYPE_ALL_F32 = 0,
89
- LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
90
- LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
91
- LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
92
- LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
93
- // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
94
- // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
95
- LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
96
- LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
97
- LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
98
- LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors
99
- LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors
100
- LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors
101
- LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors
102
- LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors
103
- LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors
104
- LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors
105
- LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors
106
- LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors
107
 
108
  LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
109
  };
110
 
 
 
 
 
 
 
 
 
111
  typedef struct llama_token_data {
112
  llama_token id; // token id
113
  float logit; // log-odds of the token
@@ -122,41 +128,75 @@ extern "C" {
122
 
123
  typedef void (*llama_progress_callback)(float progress, void *ctx);
124
 
125
- struct llama_context_params {
126
- uint32_t seed; // RNG seed, -1 for random
127
- int32_t n_ctx; // text context
128
- int32_t n_batch; // prompt processing batch size
129
- int32_t n_gpu_layers; // number of layers to store in VRAM
130
- int32_t main_gpu; // the GPU that is used for scratch and small tensors
131
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
133
 
134
- // ref: https://github.com/ggerganov/llama.cpp/pull/2054
135
- float rope_freq_base; // RoPE base frequency
136
- float rope_freq_scale; // RoPE frequency scaling factor
137
-
138
  // called with a progress value between 0 and 1, pass NULL to disable
139
  llama_progress_callback progress_callback;
140
  // context pointer passed to the progress callback
141
  void * progress_callback_user_data;
142
 
143
  // Keep the booleans together to avoid misalignment during copy-by-value.
144
- bool low_vram; // if true, reduce VRAM usage at the cost of performance
145
- bool mul_mat_q; // if true, use experimental mul_mat_q kernels
146
- bool f16_kv; // use fp16 for KV cache
147
- bool logits_all; // the llama_eval() call computes all logits, not just the last one
148
  bool vocab_only; // only load the vocabulary, no weights
149
  bool use_mmap; // use mmap if possible
150
  bool use_mlock; // force system to keep model in RAM
151
- bool embedding; // embedding mode only
152
  };
153
 
154
- // Signature for logging events
155
- // Note that text includes the new line character at the end for most events.
156
- // If your logging mechanism cannot handle that, check if the last character is '\n' and strip it
157
- // if it exists.
158
- // It might not exist for progress report where '.' is output repeatedly.
159
- typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
  // model quantization parameters
162
  typedef struct llama_model_quantize_params {
@@ -165,6 +205,7 @@ extern "C" {
165
  bool allow_requantize; // allow quantizing non-f32/f16 tensors
166
  bool quantize_output_tensor; // quantize output.weight
167
  bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
 
168
  } llama_model_quantize_params;
169
 
170
  // grammar types
@@ -215,6 +256,8 @@ extern "C" {
215
  int32_t n_eval;
216
  };
217
 
 
 
218
  LLAMA_API struct llama_context_params llama_context_default_params(void);
219
  LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
220
 
@@ -228,7 +271,7 @@ extern "C" {
228
 
229
  LLAMA_API struct llama_model * llama_load_model_from_file(
230
  const char * path_model,
231
- struct llama_context_params params);
232
 
233
  LLAMA_API void llama_free_model(struct llama_model * model);
234
 
@@ -245,25 +288,31 @@ extern "C" {
245
  LLAMA_API bool llama_mmap_supported (void);
246
  LLAMA_API bool llama_mlock_supported(void);
247
 
248
- LLAMA_API int llama_n_vocab (const struct llama_context * ctx);
 
249
  LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
250
- LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx);
251
- LLAMA_API int llama_n_embd (const struct llama_context * ctx);
252
 
253
- LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx);
 
 
 
 
254
 
255
- LLAMA_API int llama_model_n_vocab (const struct llama_model * model);
256
- LLAMA_API int llama_model_n_ctx (const struct llama_model * model);
257
- LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model);
258
- LLAMA_API int llama_model_n_embd (const struct llama_model * model);
259
 
260
  // Get a string describing the model type
261
  LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
 
262
  // Returns the total size of all the tensors in the model in bytes
263
  LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
 
264
  // Returns the total number of parameters in the model
265
  LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
266
 
 
 
 
267
  // Returns 0 on success
268
  LLAMA_API int llama_model_quantize(
269
  const char * fname_inp,
@@ -279,21 +328,70 @@ extern "C" {
279
  LLAMA_API DEPRECATED(int llama_apply_lora_from_file(
280
  struct llama_context * ctx,
281
  const char * path_lora,
 
282
  const char * path_base_model,
283
  int n_threads),
284
- "please use llama_model_apply_lora_from_file instead");
285
 
286
  LLAMA_API int llama_model_apply_lora_from_file(
287
  const struct llama_model * model,
288
- const char * path_lora,
289
- const char * path_base_model,
290
- int n_threads);
 
 
 
 
 
291
 
292
  // Returns the number of tokens in the KV cache
293
- LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
 
295
- // Sets the current rng seed.
296
- LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
 
298
  // Returns the maximum size in bytes of the state (rng, logits, embedding
299
  // and kv_cache) - will often be smaller after compacting tokens
@@ -302,48 +400,104 @@ extern "C" {
302
  // Copies the state to the specified destination address.
303
  // Destination needs to have allocated enough memory.
304
  // Returns the number of bytes copied
305
- LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst);
 
 
306
 
307
  // Set the state reading from the specified address
308
  // Returns the number of bytes read
309
- LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src);
 
 
310
 
311
  // Save/load session file
312
- LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
313
- LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);
 
 
 
 
314
 
315
- // Run the llama inference to obtain the logits and probabilities for the next token.
 
 
 
 
 
 
 
 
 
 
316
  // tokens + n_tokens is the provided batch of new tokens to process
317
  // n_past is the number of tokens to use from previous eval calls
318
  // Returns 0 on success
319
- LLAMA_API int llama_eval(
 
320
  struct llama_context * ctx,
321
- const llama_token * tokens,
322
- int n_tokens,
323
- int n_past,
324
- int n_threads);
325
 
326
  // Same as llama_eval, but use float matrix input directly.
327
- LLAMA_API int llama_eval_embd(
 
328
  struct llama_context * ctx,
329
- const float * embd,
330
- int n_tokens,
331
- int n_past,
332
- int n_threads);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
 
334
- // Export a static computation graph for context of 511 and batch size of 1
335
- // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
336
- // parameters here to keep things simple
337
- // IMPORTANT: do not use for anything else other than debugging and testing!
338
- LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname);
339
 
340
  // Token logits obtained from the last call to llama_eval()
341
  // The logits for the last token are stored in the last row
342
- // Can be mutated in order to change the probabilities of the next token
343
- // Rows: n_tokens
344
  // Cols: n_vocab
345
  LLAMA_API float * llama_get_logits(struct llama_context * ctx);
346
 
 
 
 
 
347
  // Get the embeddings for the input
348
  // shape: [n_embd] (1-dimensional)
349
  LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
@@ -352,50 +506,47 @@ extern "C" {
352
  // Vocab
353
  //
354
 
355
- LLAMA_API const char * llama_token_get_text(const struct llama_context * ctx, llama_token token);
356
 
357
- LLAMA_API float llama_token_get_score(const struct llama_context * ctx, llama_token token);
358
 
359
- LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token);
360
 
361
  // Special tokens
362
- LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence
363
- LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence
364
- LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line
 
 
 
 
 
 
365
 
366
  //
367
  // Tokenization
368
  //
369
 
370
- // Convert the provided text into tokens.
371
- // The tokens pointer must be large enough to hold the resulting tokens.
372
- // Returns the number of tokens on success, no more than n_max_tokens
373
- // Returns a negative number on failure - the number of tokens that would have been returned
 
 
374
  LLAMA_API int llama_tokenize(
375
- struct llama_context * ctx,
376
- const char * text,
377
- llama_token * tokens,
378
- int n_max_tokens,
379
- bool add_bos);
380
-
381
- LLAMA_API int llama_tokenize_with_model(
382
  const struct llama_model * model,
383
  const char * text,
 
384
  llama_token * tokens,
385
  int n_max_tokens,
386
- bool add_bos);
 
387
 
388
  // Token Id -> Piece.
389
  // Uses the vocabulary in the provided context.
390
  // Does not write null terminator to the buffer.
391
  // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
392
  LLAMA_API int llama_token_to_piece(
393
- const struct llama_context * ctx,
394
- llama_token token,
395
- char * buf,
396
- int length);
397
-
398
- LLAMA_API int llama_token_to_piece_with_model(
399
  const struct llama_model * model,
400
  llama_token token,
401
  char * buf,
@@ -418,11 +569,19 @@ extern "C" {
418
  // Sampling functions
419
  //
420
 
421
- /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
422
- LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);
423
 
 
424
  /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
425
- LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
 
 
 
 
 
 
 
426
 
427
  /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
428
  /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
@@ -435,23 +594,61 @@ extern "C" {
435
  float scale);
436
 
437
  /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
438
- LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
 
 
439
 
440
  /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
441
- LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep);
 
 
 
 
442
 
443
  /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
444
- LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
 
 
 
 
 
 
 
 
 
 
 
445
 
446
  /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
447
- LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep);
 
 
 
 
448
 
449
  /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
450
- LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
451
- LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
452
 
453
  /// @details Apply constraints from grammar
454
- LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
 
 
 
455
 
456
  /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
457
  /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
@@ -459,23 +656,42 @@ extern "C" {
459
  /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
460
  /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
461
  /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
462
- LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
 
 
 
 
 
 
463
 
464
  /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
465
  /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
466
  /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
467
  /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
468
  /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
469
- LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
 
 
 
 
 
470
 
471
  /// @details Selects the token with the highest probability.
472
- LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
 
 
 
473
 
474
  /// @details Randomly selects a token from the candidates based on their probabilities.
475
- LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
 
 
476
 
477
  /// @details Accepts the sampled token into the grammar
478
- LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
 
 
 
479
 
480
  //
481
  // Beam search
@@ -483,9 +699,10 @@ extern "C" {
483
 
484
  struct llama_beam_view {
485
  const llama_token * tokens;
 
486
  size_t n_tokens;
487
- float p; // Cumulative beam probability (renormalized relative to all beams)
488
- bool eob; // Callback should set this to true when a beam is at end-of-beam.
489
  };
490
 
491
  // Passed to beam_search_callback function.
@@ -494,9 +711,10 @@ extern "C" {
494
  // These pointers are valid only during the synchronous callback, so should not be saved.
495
  struct llama_beams_state {
496
  struct llama_beam_view * beam_views;
 
497
  size_t n_beams; // Number of elements in beam_views[].
498
  size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
499
- bool last_call; // True iff this is the last callback invocation.
500
  };
501
 
502
  // Type of pointer to the beam_search_callback function.
@@ -511,11 +729,17 @@ extern "C" {
511
  /// @param n_beams Number of beams to use.
512
  /// @param n_past Number of tokens already evaluated.
513
  /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
514
- /// @param n_threads Number of threads as passed to llama_eval().
515
- LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
 
 
 
 
 
516
 
517
  // Performance information
518
  LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
 
519
  LLAMA_API void llama_print_timings(struct llama_context * ctx);
520
  LLAMA_API void llama_reset_timings(struct llama_context * ctx);
521
 
@@ -524,7 +748,7 @@ extern "C" {
524
 
525
  // Set callback for all future logging events.
526
  // If this is not called, or NULL is supplied, everything is output on stderr.
527
- LLAMA_API void llama_log_set(llama_log_callback log_callback, void * user_data);
528
 
529
  LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
530
 
@@ -540,7 +764,9 @@ extern "C" {
540
 
541
  struct ggml_tensor;
542
 
543
- const std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx);
 
 
544
 
545
  #endif // LLAMA_API_INTERNAL
546
 
 
37
 
38
  #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
39
 
40
+ #define LLAMA_MAX_RNG_STATE (64*1024)
41
+
42
  #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
43
 
44
  #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
45
+ #define LLAMA_SESSION_VERSION 2
46
 
47
  #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
48
  // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
 
62
  struct llama_model;
63
  struct llama_context;
64
 
65
+ typedef int32_t llama_pos;
66
+ typedef int32_t llama_token;
67
+ typedef int32_t llama_seq_id;
 
 
 
 
68
 
69
  enum llama_vocab_type {
70
  LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
 
84
  // model file types
85
  enum llama_ftype {
86
  LLAMA_FTYPE_ALL_F32 = 0,
87
+ LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
88
+ LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
89
+ LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
90
+ LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
91
+ // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
92
+ // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
93
+ LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
94
+ LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
95
+ LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
96
+ LLAMA_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors
97
+ LLAMA_FTYPE_MOSTLY_Q3_K_S = 11, // except 1d tensors
98
+ LLAMA_FTYPE_MOSTLY_Q3_K_M = 12, // except 1d tensors
99
+ LLAMA_FTYPE_MOSTLY_Q3_K_L = 13, // except 1d tensors
100
+ LLAMA_FTYPE_MOSTLY_Q4_K_S = 14, // except 1d tensors
101
+ LLAMA_FTYPE_MOSTLY_Q4_K_M = 15, // except 1d tensors
102
+ LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors
103
+ LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors
104
+ LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors
105
 
106
  LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
107
  };
108
 
109
+ enum llama_rope_scaling_type {
110
+ LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
111
+ LLAMA_ROPE_SCALING_NONE = 0,
112
+ LLAMA_ROPE_SCALING_LINEAR = 1,
113
+ LLAMA_ROPE_SCALING_YARN = 2,
114
+ LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN,
115
+ };
116
+
117
  typedef struct llama_token_data {
118
  llama_token id; // token id
119
  float logit; // log-odds of the token
 
128
 
129
  typedef void (*llama_progress_callback)(float progress, void *ctx);
130
 
131
+ // Input data for llama_decode
132
+ // A llama_batch object can contain input about one or many sequences
133
+ // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
134
+ //
135
+ // - token : the token ids of the input (used when embd is NULL)
136
+ // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
137
+ // - pos : the positions of the respective token in the sequence
138
+ // - seq_id : the sequence to which the respective token belongs
139
+ // - logits : if zero, the logits for the respective token will not be output
140
+ //
141
+ typedef struct llama_batch {
142
+ int32_t n_tokens;
143
+
144
+ llama_token * token;
145
+ float * embd;
146
+ llama_pos * pos;
147
+ int32_t * n_seq_id;
148
+ llama_seq_id ** seq_id;
149
+ int8_t * logits;
150
+
151
+ // NOTE: helpers for smooth API transition - can be deprecated in the future
152
+ // for future-proof code, use the above fields instead and ignore everything below
153
+ //
154
+ // pos[i] = all_pos_0 + i*all_pos_1
155
+ //
156
+ llama_pos all_pos_0; // used if pos == NULL
157
+ llama_pos all_pos_1; // used if pos == NULL
158
+ llama_seq_id all_seq_id; // used if seq_id == NULL
159
+ } llama_batch;
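// [editor's illustration, not part of llama.h] a hedged sketch of filling a batch by hand
// using the fields documented above; it assumes the batch came from llama_batch_init()
// (declared further down), which allocates every per-token array and leaves it uninitialized
static void example_fill_batch(struct llama_batch * batch, const llama_token * tokens, int32_t n) {
    batch->n_tokens = n;
    for (int32_t i = 0; i < n; ++i) {
        batch->token   [i]    = tokens[i];
        batch->pos     [i]    = i;            // position of the token in its sequence
        batch->n_seq_id[i]    = 1;            // each token belongs to a single sequence here,
        batch->seq_id  [i][0] = 0;            // ... namely sequence 0
        batch->logits  [i]    = (i == n - 1); // request logits only for the last token
    }
}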
160
+
161
+ struct llama_model_params {
162
+ int32_t n_gpu_layers; // number of layers to store in VRAM
163
+ int32_t main_gpu; // the GPU that is used for scratch and small tensors
164
  const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
165
 
 
 
 
 
166
  // called with a progress value between 0 and 1, pass NULL to disable
167
  llama_progress_callback progress_callback;
168
  // context pointer passed to the progress callback
169
  void * progress_callback_user_data;
170
 
171
  // Keep the booleans together to avoid misalignment during copy-by-value.
 
 
 
 
172
  bool vocab_only; // only load the vocabulary, no weights
173
  bool use_mmap; // use mmap if possible
174
  bool use_mlock; // force system to keep model in RAM
 
175
  };
176
 
177
+ struct llama_context_params {
178
+ uint32_t seed; // RNG seed, -1 for random
179
+ uint32_t n_ctx; // text context, 0 = from model
180
+ uint32_t n_batch; // prompt processing maximum batch size
181
+ uint32_t n_threads; // number of threads to use for generation
182
+ uint32_t n_threads_batch; // number of threads to use for batch processing
183
+ int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
184
+
185
+ // ref: https://github.com/ggerganov/llama.cpp/pull/2054
186
+ float rope_freq_base; // RoPE base frequency, 0 = from model
187
+ float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
188
+ float yarn_ext_factor; // YaRN extrapolation mix factor, NaN = from model
189
+ float yarn_attn_factor; // YaRN magnitude scaling factor
190
+ float yarn_beta_fast; // YaRN low correction dim
191
+ float yarn_beta_slow; // YaRN high correction dim
192
+ uint32_t yarn_orig_ctx; // YaRN original context size
193
+
194
+ // Keep the booleans together to avoid misalignment during copy-by-value.
195
+ bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
196
+ bool f16_kv; // use fp16 for KV cache, fp32 otherwise
197
+ bool logits_all; // the llama_eval() call computes all logits, not just the last one
198
+ bool embedding; // embedding mode only
199
+ };
200
 
201
  // model quantization parameters
202
  typedef struct llama_model_quantize_params {
 
205
  bool allow_requantize; // allow quantizing non-f32/f16 tensors
206
  bool quantize_output_tensor; // quantize output.weight
207
  bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
208
+ bool pure; // disable k-quant mixtures and quantize all tensors to the same type
209
  } llama_model_quantize_params;
210
 
211
  // grammar types
 
256
  int32_t n_eval;
257
  };
258
 
259
+ // Helpers for getting default parameters
260
+ LLAMA_API struct llama_model_params llama_model_default_params(void);
261
  LLAMA_API struct llama_context_params llama_context_default_params(void);
262
  LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
263
 
 
271
 
272
  LLAMA_API struct llama_model * llama_load_model_from_file(
273
  const char * path_model,
274
+ struct llama_model_params params);
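// [editor's illustration, not part of llama.h] with this sync, loading parameters move out of
// llama_context_params into the new llama_model_params struct; a minimal loading sketch:
static struct llama_model * example_load_model(const char * path_model) {
    struct llama_model_params mparams = llama_model_default_params();
    mparams.n_gpu_layers = 32; // offload some layers when built with GPU support
    return llama_load_model_from_file(path_model, mparams); // NULL on failure
}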
275
 
276
  LLAMA_API void llama_free_model(struct llama_model * model);
277
 
 
288
  LLAMA_API bool llama_mmap_supported (void);
289
  LLAMA_API bool llama_mlock_supported(void);
290
 
291
+ LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
292
+
293
  LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
 
 
294
 
295
+ LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
296
+
297
+ LLAMA_API int llama_n_vocab (const struct llama_model * model);
298
+ LLAMA_API int llama_n_ctx_train(const struct llama_model * model);
299
+ LLAMA_API int llama_n_embd (const struct llama_model * model);
300
 
301
+ // Get the model's RoPE frequency scaling factor
302
+ LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
 
 
303
 
304
  // Get a string describing the model type
305
  LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
306
+
307
  // Returns the total size of all the tensors in the model in bytes
308
  LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
309
+
310
  // Returns the total number of parameters in the model
311
  LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
312
 
313
+ // Get a llama model tensor
314
+ LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
315
+
316
  // Returns 0 on success
317
  LLAMA_API int llama_model_quantize(
318
  const char * fname_inp,
 
328
  LLAMA_API DEPRECATED(int llama_apply_lora_from_file(
329
  struct llama_context * ctx,
330
  const char * path_lora,
331
+ float scale,
332
  const char * path_base_model,
333
  int n_threads),
334
+ "use llama_model_apply_lora_from_file instead");
335
 
336
  LLAMA_API int llama_model_apply_lora_from_file(
337
  const struct llama_model * model,
338
+ const char * path_lora,
339
+ float scale,
340
+ const char * path_base_model,
341
+ int n_threads);
342
+
343
+ //
344
+ // KV cache
345
+ //
346
 
347
  // Returns the number of tokens in the KV cache
348
+ LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
349
+ "avoid using this, it will be removed in the future, instead - count the tokens in user code");
350
+
351
+ // Clear the KV cache
352
+ LLAMA_API void llama_kv_cache_clear(
353
+ struct llama_context * ctx);
354
+
355
+ // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
356
+ // seq_id < 0 : match any sequence
357
+ // p0 < 0 : [0, p1]
358
+ // p1 < 0 : [p0, inf)
359
+ LLAMA_API void llama_kv_cache_seq_rm(
360
+ struct llama_context * ctx,
361
+ llama_seq_id seq_id,
362
+ llama_pos p0,
363
+ llama_pos p1);
364
+
365
+ // Copy all tokens that belong to the specified sequence to another sequence
366
+ // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
367
+ // p0 < 0 : [0, p1]
368
+ // p1 < 0 : [p0, inf)
369
+ LLAMA_API void llama_kv_cache_seq_cp(
370
+ struct llama_context * ctx,
371
+ llama_seq_id seq_id_src,
372
+ llama_seq_id seq_id_dst,
373
+ llama_pos p0,
374
+ llama_pos p1);
375
 
376
+ // Removes all tokens that do not belong to the specified sequence
377
+ LLAMA_API void llama_kv_cache_seq_keep(
378
+ struct llama_context * ctx,
379
+ llama_seq_id seq_id);
380
+
381
+ // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
382
+ // If the KV cache is RoPEd, the KV data is updated accordingly
383
+ // p0 < 0 : [0, p1]
384
+ // p1 < 0 : [p0, inf)
385
+ LLAMA_API void llama_kv_cache_seq_shift(
386
+ struct llama_context * ctx,
387
+ llama_seq_id seq_id,
388
+ llama_pos p0,
389
+ llama_pos p1,
390
+ llama_pos delta);
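// [editor's illustration, not part of llama.h] a hedged sketch of a "context shift" built from
// the calls above: keep the first n_keep tokens of sequence 0, drop the next n_discard, and
// slide the remaining positions back so generation can continue in a full cache
static void example_kv_context_shift(struct llama_context * ctx, llama_pos n_keep, llama_pos n_discard) {
    llama_kv_cache_seq_rm   (ctx, 0, n_keep,             n_keep + n_discard);
    llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, -1, -n_discard);
}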
391
+
392
+ //
393
+ // State / sessions
394
+ //
395
 
396
  // Returns the maximum size in bytes of the state (rng, logits, embedding
397
  // and kv_cache) - will often be smaller after compacting tokens
 
400
  // Copies the state to the specified destination address.
401
  // Destination needs to have allocated enough memory.
402
  // Returns the number of bytes copied
403
+ LLAMA_API size_t llama_copy_state_data(
404
+ struct llama_context * ctx,
405
+ uint8_t * dst);
406
 
407
  // Set the state reading from the specified address
408
  // Returns the number of bytes read
409
+ LLAMA_API size_t llama_set_state_data(
410
+ struct llama_context * ctx,
411
+ uint8_t * src);
412
 
413
  // Save/load session file
414
+ LLAMA_API bool llama_load_session_file(
415
+ struct llama_context * ctx,
416
+ const char * path_session,
417
+ llama_token * tokens_out,
418
+ size_t n_token_capacity,
419
+ size_t * n_token_count_out);
420
 
421
+ LLAMA_API bool llama_save_session_file(
422
+ struct llama_context * ctx,
423
+ const char * path_session,
424
+ const llama_token * tokens,
425
+ size_t n_token_count);
426
+
427
+ //
428
+ // Decoding
429
+ //
430
+
431
+ // Run the llama inference to obtain the logits and probabilities for the next token(s).
432
  // tokens + n_tokens is the provided batch of new tokens to process
433
  // n_past is the number of tokens to use from previous eval calls
434
  // Returns 0 on success
435
+ // DEPRECATED: use llama_decode() instead
436
+ LLAMA_API DEPRECATED(int llama_eval(
437
  struct llama_context * ctx,
438
+ llama_token * tokens,
439
+ int32_t n_tokens,
440
+ int n_past),
441
+ "use llama_decode() instead");
442
 
443
  // Same as llama_eval, but use float matrix input directly.
444
+ // DEPRECATED: use llama_decode() instead
445
+ LLAMA_API DEPRECATED(int llama_eval_embd(
446
  struct llama_context * ctx,
447
+ float * embd,
448
+ int32_t n_tokens,
449
+ int n_past),
450
+ "use llama_decode() instead");
451
+
452
+ // Return batch for single sequence of tokens starting at pos_0
453
+ //
454
+ // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
455
+ //
456
+ LLAMA_API struct llama_batch llama_batch_get_one(
457
+ llama_token * tokens,
458
+ int32_t n_tokens,
459
+ llama_pos pos_0,
460
+ llama_seq_id seq_id);
461
+
462
+ // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
463
+ // Each token can be assigned up to n_seq_max sequence ids
464
+ // The batch has to be freed with llama_batch_free()
465
+ // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
466
+ // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
467
+ // The rest of the llama_batch members are allocated with size n_tokens
468
+ // All members are left uninitialized
469
+ LLAMA_API struct llama_batch llama_batch_init(
470
+ int32_t n_tokens,
471
+ int32_t embd,
472
+ int32_t n_seq_max);
473
+
474
+ // Frees a batch of tokens allocated with llama_batch_init()
475
+ LLAMA_API void llama_batch_free(struct llama_batch batch);
476
+
477
+ // Positive return values do not mean a fatal error, but rather a warning.
478
+ // 0 - success
479
+ // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
480
+ // < 0 - error
481
+ LLAMA_API int llama_decode(
482
+ struct llama_context * ctx,
483
+ struct llama_batch batch);
484
 
485
+ // Set the number of threads used for decoding
486
+ // n_threads is the number of threads used for generation (single token)
487
+ // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
488
+ LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
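Taken together, the batch helpers and llama_decode() replace the deprecated eval calls above. A minimal single-sequence sketch, assuming prompt_tokens is a std::vector<llama_token> that already holds the tokenized prompt:

    // a batch from llama_batch_get_one() borrows the caller's buffer,
    // so it does not need llama_batch_free()
    llama_batch batch = llama_batch_get_one(prompt_tokens.data(),
                                            (int32_t) prompt_tokens.size(),
                                            /*pos_0 =*/ 0, /*seq_id =*/ 0);
    const int ret = llama_decode(ctx, batch);
    if (ret == 1) {
        // no KV slot found - shrink the batch or enlarge the context
    } else if (ret < 0) {
        // hard error
    }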
 
489
 
490
  // Token logits obtained from the last call to llama_eval()
491
  // The logits for the last token are stored in the last row
492
+ // Logits for which llama_batch.logits[i] == 0 are undefined
493
+ // Rows: n_tokens provided with llama_batch
494
  // Cols: n_vocab
495
  LLAMA_API float * llama_get_logits(struct llama_context * ctx);
496
 
497
+ // Logits for the ith token. Equivalent to:
498
+ // llama_get_logits(ctx) + i*n_vocab
499
+ LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
500
+
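A sketch of reading one row back out, assuming i indexes a token whose logits were requested (or the last token of the batch) and model is the llama_model in use:

    const int     n_vocab = llama_n_vocab(model);
    const float * row     = llama_get_logits_ith(ctx, i);  // == llama_get_logits(ctx) + i*n_vocab
    int best = 0;
    for (int t = 1; t < n_vocab; ++t) {
        if (row[t] > row[best]) {
            best = t;   // greedy argmax over the vocabulary
        }
    }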
501
  // Get the embeddings for the input
502
  // shape: [n_embd] (1-dimensional)
503
  LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
 
506
  // Vocab
507
  //
508
 
509
+ LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token);
510
 
511
+ LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
512
 
513
+ LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
514
 
515
  // Special tokens
516
+ LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
517
+ LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
518
+ LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
519
+
520
+ // codellama infill tokens
521
+ LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
522
+ LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
523
+ LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
524
+ LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
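A short sketch of the usual stop check and token inspection, assuming model and a freshly sampled id:

    if (id == llama_token_eos(model)) {
        // end of generation
    }
    printf("token %d -> '%s' (score %.3f)\n",
           id, llama_token_get_text(model, id), llama_token_get_score(model, id));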
525
 
526
  //
527
  // Tokenization
528
  //
529
 
530
+ /// @details Convert the provided text into tokens.
531
+ /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
532
+ /// @return Returns the number of tokens on success, no more than n_max_tokens
533
+ /// @return Returns a negative number on failure - the number of tokens that would have been returned
534
+ /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
535
+ /// Does not insert a leading space.
536
  LLAMA_API int llama_tokenize(
537
  const struct llama_model * model,
538
  const char * text,
539
+ int text_len,
540
  llama_token * tokens,
541
  int n_max_tokens,
542
+ bool add_bos,
543
+ bool special);
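The negative-return convention leads to the two-pass pattern that the updated llama_tokenize() wrapper in talk-llama.cpp below also uses. A hedged sketch, assuming model and a std::string text:

    std::vector<llama_token> toks(text.size() + 1);   // upper bound: one token per char (+ BOS)
    int n = llama_tokenize(model, text.data(), (int) text.size(),
                           toks.data(), (int) toks.size(),
                           /*add_bos =*/ true, /*special =*/ false);
    if (n < 0) {
        toks.resize(-n);                              // -n is the required capacity
        n = llama_tokenize(model, text.data(), (int) text.size(),
                           toks.data(), (int) toks.size(), true, false);
    }
    toks.resize(n);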
544
 
545
  // Token Id -> Piece.
546
  // Uses the vocabulary in the provided context.
547
  // Does not write null terminator to the buffer.
548
  // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
549
  LLAMA_API int llama_token_to_piece(
550
  const struct llama_model * model,
551
  llama_token token,
552
  char * buf,
 
569
  // Sampling functions
570
  //
571
 
572
+ // Sets the current rng seed.
573
+ LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
574
 
575
+ /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
576
  /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
577
+ LLAMA_API void llama_sample_repetition_penalties(
578
+ struct llama_context * ctx,
579
+ llama_token_data_array * candidates,
580
+ const llama_token * last_tokens,
581
+ size_t penalty_last_n,
582
+ float penalty_repeat,
583
+ float penalty_freq,
584
+ float penalty_present);
585
 
586
  /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
587
  /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
 
594
  float scale);
595
 
596
  /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
597
+ LLAMA_API void llama_sample_softmax(
598
+ struct llama_context * ctx,
599
+ llama_token_data_array * candidates);
600
 
601
  /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
602
+ LLAMA_API void llama_sample_top_k(
603
+ struct llama_context * ctx,
604
+ llama_token_data_array * candidates,
605
+ int k,
606
+ size_t min_keep);
607
 
608
  /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
609
+ LLAMA_API void llama_sample_top_p(
610
+ struct llama_context * ctx,
611
+ llama_token_data_array * candidates,
612
+ float p,
613
+ size_t min_keep);
614
+
615
+ /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
616
+ LLAMA_API void llama_sample_min_p(
617
+ struct llama_context * ctx,
618
+ llama_token_data_array * candidates,
619
+ float p,
620
+ size_t min_keep);
621
 
622
  /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
623
+ LLAMA_API void llama_sample_tail_free(
624
+ struct llama_context * ctx,
625
+ llama_token_data_array * candidates,
626
+ float z,
627
+ size_t min_keep);
628
 
629
  /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
630
+ LLAMA_API void llama_sample_typical(
631
+ struct llama_context * ctx,
632
+ llama_token_data_array * candidates,
633
+ float p,
634
+ size_t min_keep);
635
+
636
+ LLAMA_API void llama_sample_temp(
637
+ struct llama_context * ctx,
638
+ llama_token_data_array * candidates,
639
+ float temp);
640
+
641
+ LLAMA_API DEPRECATED(void llama_sample_temperature(
642
+ struct llama_context * ctx,
643
+ llama_token_data_array * candidates,
644
+ float temp),
645
+ "use llama_sample_temp instead");
646
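These penalty and truncation samplers are meant to be chained over a single llama_token_data_array before drawing a token, which is exactly what the reworked talk-llama loop below does. A sketch with illustrative parameter values, assuming cands is the candidate array and last_tokens is a buffer of recent ids:

    llama_sample_repetition_penalties(ctx, &cands,
            last_tokens.data(), last_tokens.size(),
            /*penalty_repeat =*/ 1.10f, /*penalty_freq =*/ 0.0f, /*penalty_present =*/ 0.0f);
    llama_sample_top_k(ctx, &cands, /*k =*/ 40, /*min_keep =*/ 1);
    llama_sample_top_p(ctx, &cands, /*p =*/ 0.95f, /*min_keep =*/ 1);
    llama_sample_temp (ctx, &cands, /*temp =*/ 0.80f);
    const llama_token id = llama_sample_token(ctx, &cands);   // declared further down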
 
647
  /// @details Apply constraints from grammar
648
+ LLAMA_API void llama_sample_grammar(
649
+ struct llama_context * ctx,
650
+ llama_token_data_array * candidates,
651
+ const struct llama_grammar * grammar);
652
 
653
  /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
654
  /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
 
656
  /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
657
  /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
658
  /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
659
+ LLAMA_API llama_token llama_sample_token_mirostat(
660
+ struct llama_context * ctx,
661
+ llama_token_data_array * candidates,
662
+ float tau,
663
+ float eta,
664
+ int m,
665
+ float * mu);
666
 
667
  /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
668
  /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
669
  /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
670
  /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
671
  /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
672
+ LLAMA_API llama_token llama_sample_token_mirostat_v2(
673
+ struct llama_context * ctx,
674
+ llama_token_data_array * candidates,
675
+ float tau,
676
+ float eta,
677
+ float * mu);
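A sketch of a mirostat 2.0 call, using the initialization rule stated above (mu starts at 2 * tau) and illustrative tau/eta values:

    float mu = 2.0f * 5.0f;   // must persist across sampling steps
    const llama_token id = llama_sample_token_mirostat_v2(ctx, &cands,
            /*tau =*/ 5.0f, /*eta =*/ 0.1f, &mu);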
678
 
679
  /// @details Selects the token with the highest probability.
680
+ /// Does not compute the token probabilities. Use llama_sample_softmax() instead.
681
+ LLAMA_API llama_token llama_sample_token_greedy(
682
+ struct llama_context * ctx,
683
+ llama_token_data_array * candidates);
684
 
685
  /// @details Randomly selects a token from the candidates based on their probabilities.
686
+ LLAMA_API llama_token llama_sample_token(
687
+ struct llama_context * ctx,
688
+ llama_token_data_array * candidates);
689
 
690
  /// @details Accepts the sampled token into the grammar
691
+ LLAMA_API void llama_grammar_accept_token(
692
+ struct llama_context * ctx,
693
+ struct llama_grammar * grammar,
694
+ llama_token token);
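Grammar-constrained sampling is a two-step handshake: filter the candidates, then report the chosen token back so the grammar state advances. A sketch, assuming a struct llama_grammar * grammar built elsewhere:

    llama_sample_grammar(ctx, &cands, grammar);       // mask out candidates the grammar forbids
    const llama_token id = llama_sample_token(ctx, &cands);
    llama_grammar_accept_token(ctx, grammar, id);     // advance the grammar state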
695
 
696
  //
697
  // Beam search
 
699
 
700
  struct llama_beam_view {
701
  const llama_token * tokens;
702
+
703
  size_t n_tokens;
704
+ float p; // Cumulative beam probability (renormalized relative to all beams)
705
+ bool eob; // Callback should set this to true when a beam is at end-of-beam.
706
  };
707
 
708
  // Passed to beam_search_callback function.
 
711
  // These pointers are valid only during the synchronous callback, so should not be saved.
712
  struct llama_beams_state {
713
  struct llama_beam_view * beam_views;
714
+
715
  size_t n_beams; // Number of elements in beam_views[].
716
  size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
717
+ bool last_call; // True iff this is the last callback invocation.
718
  };
719
 
720
  // Type of pointer to the beam_search_callback function.
 
729
  /// @param n_beams Number of beams to use.
730
  /// @param n_past Number of tokens already evaluated.
731
  /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
732
+ LLAMA_API void llama_beam_search(
733
+ struct llama_context * ctx,
734
+ llama_beam_search_callback_fn_t callback,
735
+ void * callback_data,
736
+ size_t n_beams,
737
+ int n_past,
738
+ int n_predict);
739
 
740
  // Performance information
741
  LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
742
+
743
  LLAMA_API void llama_print_timings(struct llama_context * ctx);
744
  LLAMA_API void llama_reset_timings(struct llama_context * ctx);
745
 
 
748
 
749
  // Set callback for all future logging events.
750
  // If this is not called, or NULL is supplied, everything is output on stderr.
751
+ LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
752
 
753
  LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
754
 
 
764
 
765
  struct ggml_tensor;
766
 
767
+ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
768
+ struct llama_context * ctx
769
+ );
770
 
771
  #endif // LLAMA_API_INTERNAL
772
 
examples/talk-llama/talk-llama.cpp CHANGED
@@ -16,21 +16,28 @@
16
  #include <regex>
17
 
18
  std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
19
- // initialize to prompt numer of chars, since n_tokens <= n_prompt_chars
20
- std::vector<llama_token> res(text.size() + (int)add_bos);
21
- int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos);
22
- assert(n >= 0);
23
- res.resize(n);
24
 
25
- return res;
26
  }
27
 
28
  std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
29
  std::vector<char> result(8, 0);
30
- const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size());
31
  if (n_tokens < 0) {
32
  result.resize(-n_tokens);
33
- int check = llama_token_to_piece(ctx, token, result.data(), result.size());
34
  GGML_ASSERT(check == -n_tokens);
35
  } else {
36
  result.resize(n_tokens);
@@ -251,16 +258,19 @@ int main(int argc, char ** argv) {
251
 
252
  llama_backend_init(true);
253
 
254
- auto lparams = llama_context_default_params();
255
 
256
- // tune these to your liking
257
- lparams.n_ctx = 2048;
258
- lparams.seed = 1;
259
- lparams.f16_kv = true;
260
 
261
- struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lparams);
262
 
263
- struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lparams);
264
 
265
  // print some info about the processing
266
  {
@@ -356,7 +366,7 @@ int main(int argc, char ** argv) {
356
  if (fp != NULL) {
357
  std::fclose(fp);
358
 
359
- session_tokens.resize(lparams.n_ctx);
360
  size_t n_token_count_out = 0;
361
  if (!llama_load_session_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
362
  fprintf(stderr, "%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
@@ -378,7 +388,7 @@ int main(int argc, char ** argv) {
378
  printf("\n");
379
  printf("%s : initializing - please wait ...\n", __func__);
380
 
381
- if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
382
  fprintf(stderr, "%s : failed to eval\n", __func__);
383
  return 1;
384
  }
@@ -561,7 +571,7 @@ int main(int argc, char ** argv) {
561
  n_session_consumed = session_tokens.size();
562
  }
563
 
564
- if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
565
  fprintf(stderr, "%s : failed to eval\n", __func__);
566
  return 1;
567
  }
@@ -593,9 +603,9 @@ int main(int argc, char ** argv) {
593
 
594
  {
595
  auto logits = llama_get_logits(ctx_llama);
596
- auto n_vocab = llama_n_vocab(ctx_llama);
597
 
598
- logits[llama_token_eos(ctx_llama)] = 0;
599
 
600
  std::vector<llama_token_data> candidates;
601
  candidates.reserve(n_vocab);
@@ -606,13 +616,13 @@ int main(int argc, char ** argv) {
606
  llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
607
 
608
  // apply repeat penalty
609
- const float nl_logit = logits[llama_token_nl(ctx_llama)];
610
 
611
- llama_sample_repetition_penalty(ctx_llama, &candidates_p,
612
  embd_inp.data() + std::max(0, n_past - repeat_last_n),
613
- repeat_last_n, repeat_penalty);
614
 
615
- logits[llama_token_nl(ctx_llama)] = nl_logit;
616
 
617
  if (temp <= 0) {
618
  // Greedy sampling
@@ -621,12 +631,12 @@ int main(int argc, char ** argv) {
621
  // Temperature sampling
622
  llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1);
623
  llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1);
624
- llama_sample_temperature(ctx_llama, &candidates_p, temp);
625
  id = llama_sample_token(ctx_llama, &candidates_p);
626
  }
627
  }
628
 
629
- if (id != llama_token_eos(ctx_llama)) {
630
  // add it to the context
631
  embd.push_back(id);
632
 
 
16
  #include <regex>
17
 
18
  std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
19
+ auto * model = llama_get_model(ctx);
20
 
21
+ // upper limit for the number of tokens
22
+ int n_tokens = text.length() + add_bos;
23
+ std::vector<llama_token> result(n_tokens);
24
+ n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, false);
25
+ if (n_tokens < 0) {
26
+ result.resize(-n_tokens);
27
+ int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, false);
28
+ GGML_ASSERT(check == -n_tokens);
29
+ } else {
30
+ result.resize(n_tokens);
31
+ }
32
+ return result;
33
  }
34
 
35
  std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
36
  std::vector<char> result(8, 0);
37
+ const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
38
  if (n_tokens < 0) {
39
  result.resize(-n_tokens);
40
+ int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
41
  GGML_ASSERT(check == -n_tokens);
42
  } else {
43
  result.resize(n_tokens);
 
258
 
259
  llama_backend_init(true);
260
 
261
+ auto lmparams = llama_model_default_params();
262
 
263
+ struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lmparams);
264
 
265
+ llama_context_params lcparams = llama_context_default_params();
266
+
267
+ // tune these to your liking
268
+ lcparams.n_ctx = 2048;
269
+ lcparams.seed = 1;
270
+ lcparams.f16_kv = true;
271
+ lcparams.n_threads = params.n_threads;
272
 
273
+ struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lcparams);
274
 
275
  // print some info about the processing
276
  {
 
366
  if (fp != NULL) {
367
  std::fclose(fp);
368
 
369
+ session_tokens.resize(llama_n_ctx(ctx_llama));
370
  size_t n_token_count_out = 0;
371
  if (!llama_load_session_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
372
  fprintf(stderr, "%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
 
388
  printf("\n");
389
  printf("%s : initializing - please wait ...\n", __func__);
390
 
391
+ if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0)) {
392
  fprintf(stderr, "%s : failed to eval\n", __func__);
393
  return 1;
394
  }
 
571
  n_session_consumed = session_tokens.size();
572
  }
573
 
574
+ if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past)) {
575
  fprintf(stderr, "%s : failed to eval\n", __func__);
576
  return 1;
577
  }
 
603
 
604
  {
605
  auto logits = llama_get_logits(ctx_llama);
606
+ auto n_vocab = llama_n_vocab(model_llama);
607
 
608
+ logits[llama_token_eos(model_llama)] = 0;
609
 
610
  std::vector<llama_token_data> candidates;
611
  candidates.reserve(n_vocab);
 
616
  llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
617
 
618
  // apply repeat penalty
619
+ const float nl_logit = logits[llama_token_nl(model_llama)];
620
 
621
+ llama_sample_repetition_penalties(ctx_llama, &candidates_p,
622
  embd_inp.data() + std::max(0, n_past - repeat_last_n),
623
+ repeat_last_n, repeat_penalty, 0.0, 0.0f);
624
 
625
+ logits[llama_token_nl(model_llama)] = nl_logit;
626
 
627
  if (temp <= 0) {
628
  // Greedy sampling
 
631
  // Temperature sampling
632
  llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1);
633
  llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1);
634
+ llama_sample_temp (ctx_llama, &candidates_p, temp);
635
  id = llama_sample_token(ctx_llama, &candidates_p);
636
  }
637
  }
638
 
639
+ if (id != llama_token_eos(model_llama)) {
640
  // add it to the context
641
  embd.push_back(id);
642
 
examples/talk-llama/unicode.h ADDED
@@ -0,0 +1,462 @@
1
+ #pragma once
2
+
3
+ #include <cassert>
4
+ #include <stdexcept>
5
+ #include <vector>
6
+ #include <unordered_map>
7
+
8
+ static const std::vector<std::pair<uint32_t, uint32_t>> digit_ranges = {
9
+ {0x30, 0x39}, {0xB2, 0xB3}, {0xB9, 0xB9}, {0x660, 0x669}, {0x6F0, 0x6F9}, {0x7C0, 0x7C9}, {0x966, 0x96F}, {0x9E6, 0x9EF}, {0xA66, 0xA6F}, {0xAE6, 0xAEF}, {0xB66, 0xB6F}, {0xBE6, 0xBEF}, {0xC66, 0xC6F},
10
+ {0xCE6, 0xCEF}, {0xD66, 0xD6F}, {0xDE6, 0xDEF}, {0xE50, 0xE59}, {0xED0, 0xED9}, {0xF20, 0xF29}, {0x1040, 0x1049}, {0x1090, 0x1099}, {0x1369, 0x1371}, {0x17E0, 0x17E9}, {0x1810, 0x1819}, {0x1946, 0x194F},
11
+ {0x19D0, 0x19DA}, {0x1A80, 0x1A89}, {0x1A90, 0x1A99}, {0x1B50, 0x1B59}, {0x1BB0, 0x1BB9}, {0x1C40, 0x1C49}, {0x1C50, 0x1C59}, {0x2070, 0x2070}, {0x2074, 0x2079}, {0x2080, 0x2089}, {0x2460, 0x2468},
12
+ {0x2474, 0x247C}, {0x2488, 0x2490}, {0x24EA, 0x24EA}, {0x24F5, 0x24FD}, {0x24FF, 0x24FF}, {0x2776, 0x277E}, {0x2780, 0x2788}, {0x278A, 0x2792}, {0xA620, 0xA629}, {0xA8D0, 0xA8D9}, {0xA900, 0xA909},
13
+ {0xA9D0, 0xA9D9}, {0xA9F0, 0xA9F9}, {0xAA50, 0xAA59}, {0xABF0, 0xABF9}, {0xFF10, 0xFF19}, {0x104A0, 0x104A9}, {0x10A40, 0x10A43}, {0x10D30, 0x10D39}, {0x10E60, 0x10E68}, {0x11052, 0x1105A},
14
+ {0x11066, 0x1106F}, {0x110F0, 0x110F9}, {0x11136, 0x1113F}, {0x111D0, 0x111D9}, {0x112F0, 0x112F9}, {0x11450, 0x11459}, {0x114D0, 0x114D9}, {0x11650, 0x11659}, {0x116C0, 0x116C9}, {0x11730, 0x11739},
15
+ {0x118E0, 0x118E9}, {0x11950, 0x11959}, {0x11C50, 0x11C59}, {0x11D50, 0x11D59}, {0x11DA0, 0x11DA9}, {0x16A60, 0x16A69}, {0x16B50, 0x16B59}, {0x1D7CE, 0x1D7FF}, {0x1E140, 0x1E149}, {0x1E2F0, 0x1E2F9},
16
+ {0x1E950, 0x1E959}, {0x1F100, 0x1F10A}, {0x1FBF0, 0x1FBF9},
17
+ };
18
+
19
+ static const std::vector<std::pair<uint32_t, uint32_t>> letter_ranges = {
20
+ {0x41, 0x5A}, {0x61, 0x7A}, {0xAA, 0xAA}, {0xB5, 0xB5}, {0xBA, 0xBA}, {0xC0, 0xD6}, {0xD8, 0xF6}, {0xF8, 0x2C1}, {0x2C6, 0x2D1}, {0x2E0, 0x2E4}, {0x2EC, 0x2EC}, {0x2EE, 0x2EE}, {0x370, 0x374},
21
+ {0x376, 0x377}, {0x37A, 0x37D}, {0x37F, 0x37F}, {0x386, 0x386}, {0x388, 0x38A}, {0x38C, 0x38C}, {0x38E, 0x3A1}, {0x3A3, 0x3F5}, {0x3F7, 0x481}, {0x48A, 0x52F}, {0x531, 0x556}, {0x559, 0x559},
22
+ {0x560, 0x588}, {0x5D0, 0x5EA}, {0x5EF, 0x5F2}, {0x620, 0x64A}, {0x66E, 0x66F}, {0x671, 0x6D3}, {0x6D5, 0x6D5}, {0x6E5, 0x6E6}, {0x6EE, 0x6EF}, {0x6FA, 0x6FC}, {0x6FF, 0x6FF}, {0x710, 0x710},
23
+ {0x712, 0x72F}, {0x74D, 0x7A5}, {0x7B1, 0x7B1}, {0x7CA, 0x7EA}, {0x7F4, 0x7F5}, {0x7FA, 0x7FA}, {0x800, 0x815}, {0x81A, 0x81A}, {0x824, 0x824}, {0x828, 0x828}, {0x840, 0x858}, {0x860, 0x86A},
24
+ {0x8A0, 0x8B4}, {0x8B6, 0x8C7}, {0x904, 0x939}, {0x93D, 0x93D}, {0x950, 0x950}, {0x958, 0x961}, {0x971, 0x980}, {0x985, 0x98C}, {0x98F, 0x990}, {0x993, 0x9A8}, {0x9AA, 0x9B0}, {0x9B2, 0x9B2},
25
+ {0x9B6, 0x9B9}, {0x9BD, 0x9BD}, {0x9CE, 0x9CE}, {0x9DC, 0x9DD}, {0x9DF, 0x9E1}, {0x9F0, 0x9F1}, {0x9FC, 0x9FC}, {0xA05, 0xA0A}, {0xA0F, 0xA10}, {0xA13, 0xA28}, {0xA2A, 0xA30}, {0xA32, 0xA33},
26
+ {0xA35, 0xA36}, {0xA38, 0xA39}, {0xA59, 0xA5C}, {0xA5E, 0xA5E}, {0xA72, 0xA74}, {0xA85, 0xA8D}, {0xA8F, 0xA91}, {0xA93, 0xAA8}, {0xAAA, 0xAB0}, {0xAB2, 0xAB3}, {0xAB5, 0xAB9}, {0xABD, 0xABD},
27
+ {0xAD0, 0xAD0}, {0xAE0, 0xAE1}, {0xAF9, 0xAF9}, {0xB05, 0xB0C}, {0xB0F, 0xB10}, {0xB13, 0xB28}, {0xB2A, 0xB30}, {0xB32, 0xB33}, {0xB35, 0xB39}, {0xB3D, 0xB3D}, {0xB5C, 0xB5D}, {0xB5F, 0xB61},
28
+ {0xB71, 0xB71}, {0xB83, 0xB83}, {0xB85, 0xB8A}, {0xB8E, 0xB90}, {0xB92, 0xB95}, {0xB99, 0xB9A}, {0xB9C, 0xB9C}, {0xB9E, 0xB9F}, {0xBA3, 0xBA4}, {0xBA8, 0xBAA}, {0xBAE, 0xBB9}, {0xBD0, 0xBD0},
29
+ {0xC05, 0xC0C}, {0xC0E, 0xC10}, {0xC12, 0xC28}, {0xC2A, 0xC39}, {0xC3D, 0xC3D}, {0xC58, 0xC5A}, {0xC60, 0xC61}, {0xC80, 0xC80}, {0xC85, 0xC8C}, {0xC8E, 0xC90}, {0xC92, 0xCA8}, {0xCAA, 0xCB3},
30
+ {0xCB5, 0xCB9}, {0xCBD, 0xCBD}, {0xCDE, 0xCDE}, {0xCE0, 0xCE1}, {0xCF1, 0xCF2}, {0xD04, 0xD0C}, {0xD0E, 0xD10}, {0xD12, 0xD3A}, {0xD3D, 0xD3D}, {0xD4E, 0xD4E}, {0xD54, 0xD56}, {0xD5F, 0xD61},
31
+ {0xD7A, 0xD7F}, {0xD85, 0xD96}, {0xD9A, 0xDB1}, {0xDB3, 0xDBB}, {0xDBD, 0xDBD}, {0xDC0, 0xDC6}, {0xE01, 0xE30}, {0xE32, 0xE33}, {0xE40, 0xE46}, {0xE81, 0xE82}, {0xE84, 0xE84}, {0xE86, 0xE8A},
32
+ {0xE8C, 0xEA3}, {0xEA5, 0xEA5}, {0xEA7, 0xEB0}, {0xEB2, 0xEB3}, {0xEBD, 0xEBD}, {0xEC0, 0xEC4}, {0xEC6, 0xEC6}, {0xEDC, 0xEDF}, {0xF00, 0xF00}, {0xF40, 0xF47}, {0xF49, 0xF6C}, {0xF88, 0xF8C},
33
+ {0x1000, 0x102A}, {0x103F, 0x103F}, {0x1050, 0x1055}, {0x105A, 0x105D}, {0x1061, 0x1061}, {0x1065, 0x1066}, {0x106E, 0x1070}, {0x1075, 0x1081}, {0x108E, 0x108E}, {0x10A0, 0x10C5}, {0x10C7, 0x10C7},
34
+ {0x10CD, 0x10CD}, {0x10D0, 0x10FA}, {0x10FC, 0x1248}, {0x124A, 0x124D}, {0x1250, 0x1256}, {0x1258, 0x1258}, {0x125A, 0x125D}, {0x1260, 0x1288}, {0x128A, 0x128D}, {0x1290, 0x12B0}, {0x12B2, 0x12B5},
35
+ {0x12B8, 0x12BE}, {0x12C0, 0x12C0}, {0x12C2, 0x12C5}, {0x12C8, 0x12D6}, {0x12D8, 0x1310}, {0x1312, 0x1315}, {0x1318, 0x135A}, {0x1380, 0x138F}, {0x13A0, 0x13F5}, {0x13F8, 0x13FD}, {0x1401, 0x166C},
36
+ {0x166F, 0x167F}, {0x1681, 0x169A}, {0x16A0, 0x16EA}, {0x16F1, 0x16F8}, {0x1700, 0x170C}, {0x170E, 0x1711}, {0x1720, 0x1731}, {0x1740, 0x1751}, {0x1760, 0x176C}, {0x176E, 0x1770}, {0x1780, 0x17B3},
37
+ {0x17D7, 0x17D7}, {0x17DC, 0x17DC}, {0x1820, 0x1878}, {0x1880, 0x1884}, {0x1887, 0x18A8}, {0x18AA, 0x18AA}, {0x18B0, 0x18F5}, {0x1900, 0x191E}, {0x1950, 0x196D}, {0x1970, 0x1974}, {0x1980, 0x19AB},
38
+ {0x19B0, 0x19C9}, {0x1A00, 0x1A16}, {0x1A20, 0x1A54}, {0x1AA7, 0x1AA7}, {0x1B05, 0x1B33}, {0x1B45, 0x1B4B}, {0x1B83, 0x1BA0}, {0x1BAE, 0x1BAF}, {0x1BBA, 0x1BE5}, {0x1C00, 0x1C23}, {0x1C4D, 0x1C4F},
39
+ {0x1C5A, 0x1C7D}, {0x1C80, 0x1C88}, {0x1C90, 0x1CBA}, {0x1CBD, 0x1CBF}, {0x1CE9, 0x1CEC}, {0x1CEE, 0x1CF3}, {0x1CF5, 0x1CF6}, {0x1CFA, 0x1CFA}, {0x1D00, 0x1DBF}, {0x1E00, 0x1F15}, {0x1F18, 0x1F1D},
40
+ {0x1F20, 0x1F45}, {0x1F48, 0x1F4D}, {0x1F50, 0x1F57}, {0x1F59, 0x1F59}, {0x1F5B, 0x1F5B}, {0x1F5D, 0x1F5D}, {0x1F5F, 0x1F7D}, {0x1F80, 0x1FB4}, {0x1FB6, 0x1FBC}, {0x1FBE, 0x1FBE}, {0x1FC2, 0x1FC4},
41
+ {0x1FC6, 0x1FCC}, {0x1FD0, 0x1FD3}, {0x1FD6, 0x1FDB}, {0x1FE0, 0x1FEC}, {0x1FF2, 0x1FF4}, {0x1FF6, 0x1FFC}, {0x2071, 0x2071}, {0x207F, 0x207F}, {0x2090, 0x209C}, {0x2102, 0x2102}, {0x2107, 0x2107},
42
+ {0x210A, 0x2113}, {0x2115, 0x2115}, {0x2119, 0x211D}, {0x2124, 0x2124}, {0x2126, 0x2126}, {0x2128, 0x2128}, {0x212A, 0x212D}, {0x212F, 0x2139}, {0x213C, 0x213F}, {0x2145, 0x2149}, {0x214E, 0x214E},
43
+ {0x2183, 0x2184}, {0x2C00, 0x2C2E}, {0x2C30, 0x2C5E}, {0x2C60, 0x2CE4}, {0x2CEB, 0x2CEE}, {0x2CF2, 0x2CF3}, {0x2D00, 0x2D25}, {0x2D27, 0x2D27}, {0x2D2D, 0x2D2D}, {0x2D30, 0x2D67}, {0x2D6F, 0x2D6F},
44
+ {0x2D80, 0x2D96}, {0x2DA0, 0x2DA6}, {0x2DA8, 0x2DAE}, {0x2DB0, 0x2DB6}, {0x2DB8, 0x2DBE}, {0x2DC0, 0x2DC6}, {0x2DC8, 0x2DCE}, {0x2DD0, 0x2DD6}, {0x2DD8, 0x2DDE}, {0x2E2F, 0x2E2F}, {0x3005, 0x3006},
45
+ {0x3031, 0x3035}, {0x303B, 0x303C}, {0x3041, 0x3096}, {0x309D, 0x309F}, {0x30A1, 0x30FA}, {0x30FC, 0x30FF}, {0x3105, 0x312F}, {0x3131, 0x318E}, {0x31A0, 0x31BF}, {0x31F0, 0x31FF}, {0x3400, 0x4DBF},
46
+ {0x4E00, 0x9FFC}, {0xA000, 0xA48C}, {0xA4D0, 0xA4FD}, {0xA500, 0xA60C}, {0xA610, 0xA61F}, {0xA62A, 0xA62B}, {0xA640, 0xA66E}, {0xA67F, 0xA69D}, {0xA6A0, 0xA6E5}, {0xA717, 0xA71F}, {0xA722, 0xA788},
47
+ {0xA78B, 0xA7BF}, {0xA7C2, 0xA7CA}, {0xA7F5, 0xA801}, {0xA803, 0xA805}, {0xA807, 0xA80A}, {0xA80C, 0xA822}, {0xA840, 0xA873}, {0xA882, 0xA8B3}, {0xA8F2, 0xA8F7}, {0xA8FB, 0xA8FB}, {0xA8FD, 0xA8FE},
48
+ {0xA90A, 0xA925}, {0xA930, 0xA946}, {0xA960, 0xA97C}, {0xA984, 0xA9B2}, {0xA9CF, 0xA9CF}, {0xA9E0, 0xA9E4}, {0xA9E6, 0xA9EF}, {0xA9FA, 0xA9FE}, {0xAA00, 0xAA28}, {0xAA40, 0xAA42}, {0xAA44, 0xAA4B},
49
+ {0xAA60, 0xAA76}, {0xAA7A, 0xAA7A}, {0xAA7E, 0xAAAF}, {0xAAB1, 0xAAB1}, {0xAAB5, 0xAAB6}, {0xAAB9, 0xAABD}, {0xAAC0, 0xAAC0}, {0xAAC2, 0xAAC2}, {0xAADB, 0xAADD}, {0xAAE0, 0xAAEA}, {0xAAF2, 0xAAF4},
50
+ {0xAB01, 0xAB06}, {0xAB09, 0xAB0E}, {0xAB11, 0xAB16}, {0xAB20, 0xAB26}, {0xAB28, 0xAB2E}, {0xAB30, 0xAB5A}, {0xAB5C, 0xAB69}, {0xAB70, 0xABE2}, {0xAC00, 0xD7A3}, {0xD7B0, 0xD7C6}, {0xD7CB, 0xD7FB},
51
+ {0xF900, 0xFA6D}, {0xFA70, 0xFAD9}, {0xFB00, 0xFB06}, {0xFB13, 0xFB17}, {0xFB1D, 0xFB1D}, {0xFB1F, 0xFB28}, {0xFB2A, 0xFB36}, {0xFB38, 0xFB3C}, {0xFB3E, 0xFB3E}, {0xFB40, 0xFB41}, {0xFB43, 0xFB44},
52
+ {0xFB46, 0xFBB1}, {0xFBD3, 0xFD3D}, {0xFD50, 0xFD8F}, {0xFD92, 0xFDC7}, {0xFDF0, 0xFDFB}, {0xFE70, 0xFE74}, {0xFE76, 0xFEFC}, {0xFF21, 0xFF3A}, {0xFF41, 0xFF5A}, {0xFF66, 0xFFBE}, {0xFFC2, 0xFFC7},
53
+ {0xFFCA, 0xFFCF}, {0xFFD2, 0xFFD7}, {0xFFDA, 0xFFDC}, {0x10000, 0x1000B}, {0x1000D, 0x10026}, {0x10028, 0x1003A}, {0x1003C, 0x1003D}, {0x1003F, 0x1004D}, {0x10050, 0x1005D}, {0x10080, 0x100FA},
54
+ {0x10280, 0x1029C}, {0x102A0, 0x102D0}, {0x10300, 0x1031F}, {0x1032D, 0x10340}, {0x10342, 0x10349}, {0x10350, 0x10375}, {0x10380, 0x1039D}, {0x103A0, 0x103C3}, {0x103C8, 0x103CF}, {0x10400, 0x1049D},
55
+ {0x104B0, 0x104D3}, {0x104D8, 0x104FB}, {0x10500, 0x10527}, {0x10530, 0x10563}, {0x10600, 0x10736}, {0x10740, 0x10755}, {0x10760, 0x10767}, {0x10800, 0x10805}, {0x10808, 0x10808}, {0x1080A, 0x10835},
56
+ {0x10837, 0x10838}, {0x1083C, 0x1083C}, {0x1083F, 0x10855}, {0x10860, 0x10876}, {0x10880, 0x1089E}, {0x108E0, 0x108F2}, {0x108F4, 0x108F5}, {0x10900, 0x10915}, {0x10920, 0x10939}, {0x10980, 0x109B7},
57
+ {0x109BE, 0x109BF}, {0x10A00, 0x10A00}, {0x10A10, 0x10A13}, {0x10A15, 0x10A17}, {0x10A19, 0x10A35}, {0x10A60, 0x10A7C}, {0x10A80, 0x10A9C}, {0x10AC0, 0x10AC7}, {0x10AC9, 0x10AE4}, {0x10B00, 0x10B35},
58
+ {0x10B40, 0x10B55}, {0x10B60, 0x10B72}, {0x10B80, 0x10B91}, {0x10C00, 0x10C48}, {0x10C80, 0x10CB2}, {0x10CC0, 0x10CF2}, {0x10D00, 0x10D23}, {0x10E80, 0x10EA9}, {0x10EB0, 0x10EB1}, {0x10F00, 0x10F1C},
59
+ {0x10F27, 0x10F27}, {0x10F30, 0x10F45}, {0x10FB0, 0x10FC4}, {0x10FE0, 0x10FF6}, {0x11003, 0x11037}, {0x11083, 0x110AF}, {0x110D0, 0x110E8}, {0x11103, 0x11126}, {0x11144, 0x11144}, {0x11147, 0x11147},
60
+ {0x11150, 0x11172}, {0x11176, 0x11176}, {0x11183, 0x111B2}, {0x111C1, 0x111C4}, {0x111DA, 0x111DA}, {0x111DC, 0x111DC}, {0x11200, 0x11211}, {0x11213, 0x1122B}, {0x11280, 0x11286}, {0x11288, 0x11288},
61
+ {0x1128A, 0x1128D}, {0x1128F, 0x1129D}, {0x1129F, 0x112A8}, {0x112B0, 0x112DE}, {0x11305, 0x1130C}, {0x1130F, 0x11310}, {0x11313, 0x11328}, {0x1132A, 0x11330}, {0x11332, 0x11333}, {0x11335, 0x11339},
62
+ {0x1133D, 0x1133D}, {0x11350, 0x11350}, {0x1135D, 0x11361}, {0x11400, 0x11434}, {0x11447, 0x1144A}, {0x1145F, 0x11461}, {0x11480, 0x114AF}, {0x114C4, 0x114C5}, {0x114C7, 0x114C7}, {0x11580, 0x115AE},
63
+ {0x115D8, 0x115DB}, {0x11600, 0x1162F}, {0x11644, 0x11644}, {0x11680, 0x116AA}, {0x116B8, 0x116B8}, {0x11700, 0x1171A}, {0x11800, 0x1182B}, {0x118A0, 0x118DF}, {0x118FF, 0x11906}, {0x11909, 0x11909},
64
+ {0x1190C, 0x11913}, {0x11915, 0x11916}, {0x11918, 0x1192F}, {0x1193F, 0x1193F}, {0x11941, 0x11941}, {0x119A0, 0x119A7}, {0x119AA, 0x119D0}, {0x119E1, 0x119E1}, {0x119E3, 0x119E3}, {0x11A00, 0x11A00},
65
+ {0x11A0B, 0x11A32}, {0x11A3A, 0x11A3A}, {0x11A50, 0x11A50}, {0x11A5C, 0x11A89}, {0x11A9D, 0x11A9D}, {0x11AC0, 0x11AF8}, {0x11C00, 0x11C08}, {0x11C0A, 0x11C2E}, {0x11C40, 0x11C40}, {0x11C72, 0x11C8F},
66
+ {0x11D00, 0x11D06}, {0x11D08, 0x11D09}, {0x11D0B, 0x11D30}, {0x11D46, 0x11D46}, {0x11D60, 0x11D65}, {0x11D67, 0x11D68}, {0x11D6A, 0x11D89}, {0x11D98, 0x11D98}, {0x11EE0, 0x11EF2}, {0x11FB0, 0x11FB0},
67
+ {0x12000, 0x12399}, {0x12480, 0x12543}, {0x13000, 0x1342E}, {0x14400, 0x14646}, {0x16800, 0x16A38}, {0x16A40, 0x16A5E}, {0x16AD0, 0x16AED}, {0x16B00, 0x16B2F}, {0x16B40, 0x16B43}, {0x16B63, 0x16B77},
68
+ {0x16B7D, 0x16B8F}, {0x16E40, 0x16E7F}, {0x16F00, 0x16F4A}, {0x16F50, 0x16F50}, {0x16F93, 0x16F9F}, {0x16FE0, 0x16FE1}, {0x16FE3, 0x16FE3}, {0x17000, 0x187F7}, {0x18800, 0x18CD5}, {0x18D00, 0x18D08},
69
+ {0x1B000, 0x1B11E}, {0x1B150, 0x1B152}, {0x1B164, 0x1B167}, {0x1B170, 0x1B2FB}, {0x1BC00, 0x1BC6A}, {0x1BC70, 0x1BC7C}, {0x1BC80, 0x1BC88}, {0x1BC90, 0x1BC99}, {0x1D400, 0x1D454}, {0x1D456, 0x1D49C},
70
+ {0x1D49E, 0x1D49F}, {0x1D4A2, 0x1D4A2}, {0x1D4A5, 0x1D4A6}, {0x1D4A9, 0x1D4AC}, {0x1D4AE, 0x1D4B9}, {0x1D4BB, 0x1D4BB}, {0x1D4BD, 0x1D4C3}, {0x1D4C5, 0x1D505}, {0x1D507, 0x1D50A}, {0x1D50D, 0x1D514},
71
+ {0x1D516, 0x1D51C}, {0x1D51E, 0x1D539}, {0x1D53B, 0x1D53E}, {0x1D540, 0x1D544}, {0x1D546, 0x1D546}, {0x1D54A, 0x1D550}, {0x1D552, 0x1D6A5}, {0x1D6A8, 0x1D6C0}, {0x1D6C2, 0x1D6DA}, {0x1D6DC, 0x1D6FA},
72
+ {0x1D6FC, 0x1D714}, {0x1D716, 0x1D734}, {0x1D736, 0x1D74E}, {0x1D750, 0x1D76E}, {0x1D770, 0x1D788}, {0x1D78A, 0x1D7A8}, {0x1D7AA, 0x1D7C2}, {0x1D7C4, 0x1D7CB}, {0x1E100, 0x1E12C}, {0x1E137, 0x1E13D},
73
+ {0x1E14E, 0x1E14E}, {0x1E2C0, 0x1E2EB}, {0x1E800, 0x1E8C4}, {0x1E900, 0x1E943}, {0x1E94B, 0x1E94B}, {0x1EE00, 0x1EE03}, {0x1EE05, 0x1EE1F}, {0x1EE21, 0x1EE22}, {0x1EE24, 0x1EE24}, {0x1EE27, 0x1EE27},
74
+ {0x1EE29, 0x1EE32}, {0x1EE34, 0x1EE37}, {0x1EE39, 0x1EE39}, {0x1EE3B, 0x1EE3B}, {0x1EE42, 0x1EE42}, {0x1EE47, 0x1EE47}, {0x1EE49, 0x1EE49}, {0x1EE4B, 0x1EE4B}, {0x1EE4D, 0x1EE4F}, {0x1EE51, 0x1EE52},
75
+ {0x1EE54, 0x1EE54}, {0x1EE57, 0x1EE57}, {0x1EE59, 0x1EE59}, {0x1EE5B, 0x1EE5B}, {0x1EE5D, 0x1EE5D}, {0x1EE5F, 0x1EE5F}, {0x1EE61, 0x1EE62}, {0x1EE64, 0x1EE64}, {0x1EE67, 0x1EE6A}, {0x1EE6C, 0x1EE72},
76
+ {0x1EE74, 0x1EE77}, {0x1EE79, 0x1EE7C}, {0x1EE7E, 0x1EE7E}, {0x1EE80, 0x1EE89}, {0x1EE8B, 0x1EE9B}, {0x1EEA1, 0x1EEA3}, {0x1EEA5, 0x1EEA9}, {0x1EEAB, 0x1EEBB}, {0x20000, 0x2A6DD}, {0x2A700, 0x2B734},
77
+ {0x2B740, 0x2B81D}, {0x2B820, 0x2CEA1}, {0x2CEB0, 0x2EBE0}, {0x2F800, 0x2FA1D}, {0x30000, 0x3134A},
78
+ };
79
+
80
+ static const std::vector<std::pair<uint32_t, uint32_t>> whitespace_ranges = {
81
+ {0x9, 0xD}, {0x1C, 0x20}, {0x85, 0x85}, {0xA0, 0xA0}, {0x1680, 0x1680}, {0x2000, 0x200A}, {0x2028, 0x2029}, {0x202F, 0x202F}, {0x205F, 0x205F}, {0x3000, 0x3000},
82
+ };
83
+
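These tables are inclusive {first, last} codepoint ranges. A hypothetical helper (not part of unicode.h) showing how a classifier can consume them:

    static bool codepoint_in_ranges(uint32_t cp,
            const std::vector<std::pair<uint32_t, uint32_t>> & ranges) {
        for (const auto & r : ranges) {
            if (cp >= r.first && cp <= r.second) {
                return true;
            }
        }
        return false;
    }
    // e.g. codepoint_in_ranges(0x3000, whitespace_ranges) -> true (ideographic space)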
84
+ static const std::vector<std::pair<uint32_t, uint32_t>> accent_mark_ranges = {
85
+ {0x300, 0x36F}, {0x483, 0x489}, {0x591, 0x5BD}, {0x5BF, 0x5BF}, {0x5C1, 0x5C2}, {0x5C4, 0x5C5}, {0x5C7, 0x5C7}, {0x610, 0x61A}, {0x64B, 0x65F}, {0x670, 0x670}, {0x6D6, 0x6DC}, {0x6DF, 0x6E4},
86
+ {0x6E7, 0x6E8}, {0x6EA, 0x6ED}, {0x711, 0x711}, {0x730, 0x74A}, {0x7A6, 0x7B0}, {0x7EB, 0x7F3}, {0x7FD, 0x7FD}, {0x816, 0x819}, {0x81B, 0x823}, {0x825, 0x827}, {0x829, 0x82D}, {0x859, 0x85B},
87
+ {0x8D3, 0x8E1}, {0x8E3, 0x903}, {0x93A, 0x93C}, {0x93E, 0x94F}, {0x951, 0x957}, {0x962, 0x963}, {0x981, 0x983}, {0x9BC, 0x9BC}, {0x9BE, 0x9C4}, {0x9C7, 0x9C8}, {0x9CB, 0x9CD}, {0x9D7, 0x9D7},
88
+ {0x9E2, 0x9E3}, {0x9FE, 0x9FE}, {0xA01, 0xA03}, {0xA3C, 0xA3C}, {0xA3E, 0xA42}, {0xA47, 0xA48}, {0xA4B, 0xA4D}, {0xA51, 0xA51}, {0xA70, 0xA71}, {0xA75, 0xA75}, {0xA81, 0xA83}, {0xABC, 0xABC},
89
+ {0xABE, 0xAC5}, {0xAC7, 0xAC9}, {0xACB, 0xACD}, {0xAE2, 0xAE3}, {0xAFA, 0xAFF}, {0xB01, 0xB03}, {0xB3C, 0xB3C}, {0xB3E, 0xB44}, {0xB47, 0xB48}, {0xB4B, 0xB4D}, {0xB55, 0xB57}, {0xB62, 0xB63},
90
+ {0xB82, 0xB82}, {0xBBE, 0xBC2}, {0xBC6, 0xBC8}, {0xBCA, 0xBCD}, {0xBD7, 0xBD7}, {0xC00, 0xC04}, {0xC3E, 0xC44}, {0xC46, 0xC48}, {0xC4A, 0xC4D}, {0xC55, 0xC56}, {0xC62, 0xC63}, {0xC81, 0xC83},
91
+ {0xCBC, 0xCBC}, {0xCBE, 0xCC4}, {0xCC6, 0xCC8}, {0xCCA, 0xCCD}, {0xCD5, 0xCD6}, {0xCE2, 0xCE3}, {0xD00, 0xD03}, {0xD3B, 0xD3C}, {0xD3E, 0xD44}, {0xD46, 0xD48}, {0xD4A, 0xD4D}, {0xD57, 0xD57},
92
+ {0xD62, 0xD63}, {0xD81, 0xD83}, {0xDCA, 0xDCA}, {0xDCF, 0xDD4}, {0xDD6, 0xDD6}, {0xDD8, 0xDDF}, {0xDF2, 0xDF3}, {0xE31, 0xE31}, {0xE34, 0xE3A}, {0xE47, 0xE4E}, {0xEB1, 0xEB1}, {0xEB4, 0xEBC},
93
+ {0xEC8, 0xECD}, {0xF18, 0xF19}, {0xF35, 0xF35}, {0xF37, 0xF37}, {0xF39, 0xF39}, {0xF3E, 0xF3F}, {0xF71, 0xF84}, {0xF86, 0xF87}, {0xF8D, 0xF97}, {0xF99, 0xFBC}, {0xFC6, 0xFC6}, {0x102B, 0x103E},
94
+ {0x1056, 0x1059}, {0x105E, 0x1060}, {0x1062, 0x1064}, {0x1067, 0x106D}, {0x1071, 0x1074}, {0x1082, 0x108D}, {0x108F, 0x108F}, {0x109A, 0x109D}, {0x135D, 0x135F}, {0x1712, 0x1714}, {0x1732, 0x1734},
95
+ {0x1752, 0x1753}, {0x1772, 0x1773}, {0x17B4, 0x17D3}, {0x17DD, 0x17DD}, {0x180B, 0x180D}, {0x1885, 0x1886}, {0x18A9, 0x18A9}, {0x1920, 0x192B}, {0x1930, 0x193B}, {0x1A17, 0x1A1B}, {0x1A55, 0x1A5E},
96
+ {0x1A60, 0x1A7C}, {0x1A7F, 0x1A7F}, {0x1AB0, 0x1AC0}, {0x1B00, 0x1B04}, {0x1B34, 0x1B44}, {0x1B6B, 0x1B73}, {0x1B80, 0x1B82}, {0x1BA1, 0x1BAD}, {0x1BE6, 0x1BF3}, {0x1C24, 0x1C37}, {0x1CD0, 0x1CD2},
97
+ {0x1CD4, 0x1CE8}, {0x1CED, 0x1CED}, {0x1CF4, 0x1CF4}, {0x1CF7, 0x1CF9}, {0x1DC0, 0x1DF9}, {0x1DFB, 0x1DFF}, {0x20D0, 0x20F0}, {0x2CEF, 0x2CF1}, {0x2D7F, 0x2D7F}, {0x2DE0, 0x2DFF}, {0x302A, 0x302F},
98
+ {0x3099, 0x309A}, {0xA66F, 0xA672}, {0xA674, 0xA67D}, {0xA69E, 0xA69F}, {0xA6F0, 0xA6F1}, {0xA802, 0xA802}, {0xA806, 0xA806}, {0xA80B, 0xA80B}, {0xA823, 0xA827}, {0xA82C, 0xA82C}, {0xA880, 0xA881},
99
+ {0xA8B4, 0xA8C5}, {0xA8E0, 0xA8F1}, {0xA8FF, 0xA8FF}, {0xA926, 0xA92D}, {0xA947, 0xA953}, {0xA980, 0xA983}, {0xA9B3, 0xA9C0}, {0xA9E5, 0xA9E5}, {0xAA29, 0xAA36}, {0xAA43, 0xAA43}, {0xAA4C, 0xAA4D},
100
+ {0xAA7B, 0xAA7D}, {0xAAB0, 0xAAB0}, {0xAAB2, 0xAAB4}, {0xAAB7, 0xAAB8}, {0xAABE, 0xAABF}, {0xAAC1, 0xAAC1}, {0xAAEB, 0xAAEF}, {0xAAF5, 0xAAF6}, {0xABE3, 0xABEA}, {0xABEC, 0xABED}, {0xFB1E, 0xFB1E},
101
+ {0xFE00, 0xFE0F}, {0xFE20, 0xFE2F}, {0x101FD, 0x101FD}, {0x102E0, 0x102E0}, {0x10376, 0x1037A}, {0x10A01, 0x10A03}, {0x10A05, 0x10A06}, {0x10A0C, 0x10A0F}, {0x10A38, 0x10A3A}, {0x10A3F, 0x10A3F},
102
+ {0x10AE5, 0x10AE6}, {0x10D24, 0x10D27}, {0x10EAB, 0x10EAC}, {0x10F46, 0x10F50}, {0x11000, 0x11002}, {0x11038, 0x11046}, {0x1107F, 0x11082}, {0x110B0, 0x110BA}, {0x11100, 0x11102}, {0x11127, 0x11134},
103
+ {0x11145, 0x11146}, {0x11173, 0x11173}, {0x11180, 0x11182}, {0x111B3, 0x111C0}, {0x111C9, 0x111CC}, {0x111CE, 0x111CF}, {0x1122C, 0x11237}, {0x1123E, 0x1123E}, {0x112DF, 0x112EA}, {0x11300, 0x11303},
104
+ {0x1133B, 0x1133C}, {0x1133E, 0x11344}, {0x11347, 0x11348}, {0x1134B, 0x1134D}, {0x11357, 0x11357}, {0x11362, 0x11363}, {0x11366, 0x1136C}, {0x11370, 0x11374}, {0x11435, 0x11446}, {0x1145E, 0x1145E},
105
+ {0x114B0, 0x114C3}, {0x115AF, 0x115B5}, {0x115B8, 0x115C0}, {0x115DC, 0x115DD}, {0x11630, 0x11640}, {0x116AB, 0x116B7}, {0x1171D, 0x1172B}, {0x1182C, 0x1183A}, {0x11930, 0x11935}, {0x11937, 0x11938},
106
+ {0x1193B, 0x1193E}, {0x11940, 0x11940}, {0x11942, 0x11943}, {0x119D1, 0x119D7}, {0x119DA, 0x119E0}, {0x119E4, 0x119E4}, {0x11A01, 0x11A0A}, {0x11A33, 0x11A39}, {0x11A3B, 0x11A3E}, {0x11A47, 0x11A47},
107
+ {0x11A51, 0x11A5B}, {0x11A8A, 0x11A99}, {0x11C2F, 0x11C36}, {0x11C38, 0x11C3F}, {0x11C92, 0x11CA7}, {0x11CA9, 0x11CB6}, {0x11D31, 0x11D36}, {0x11D3A, 0x11D3A}, {0x11D3C, 0x11D3D}, {0x11D3F, 0x11D45},
108
+ {0x11D47, 0x11D47}, {0x11D8A, 0x11D8E}, {0x11D90, 0x11D91}, {0x11D93, 0x11D97}, {0x11EF3, 0x11EF6}, {0x16AF0, 0x16AF4}, {0x16B30, 0x16B36}, {0x16F4F, 0x16F4F}, {0x16F51, 0x16F87}, {0x16F8F, 0x16F92},
109
+ {0x16FE4, 0x16FE4}, {0x16FF0, 0x16FF1}, {0x1BC9D, 0x1BC9E}, {0x1D165, 0x1D169}, {0x1D16D, 0x1D172}, {0x1D17B, 0x1D182}, {0x1D185, 0x1D18B}, {0x1D1AA, 0x1D1AD}, {0x1D242, 0x1D244}, {0x1DA00, 0x1DA36},
110
+ {0x1DA3B, 0x1DA6C}, {0x1DA75, 0x1DA75}, {0x1DA84, 0x1DA84}, {0x1DA9B, 0x1DA9F}, {0x1DAA1, 0x1DAAF}, {0x1E000, 0x1E006}, {0x1E008, 0x1E018}, {0x1E01B, 0x1E021}, {0x1E023, 0x1E024}, {0x1E026, 0x1E02A},
111
+ {0x1E130, 0x1E136}, {0x1E2EC, 0x1E2EF}, {0x1E8D0, 0x1E8D6}, {0x1E944, 0x1E94A}, {0xE0100, 0xE01EF},
112
+ };
113
+
114
+ static const std::vector<std::pair<uint32_t, uint32_t>> punctuation_ranges = {
115
+ {0x21, 0x23}, {0x25, 0x2A}, {0x2C, 0x2F}, {0x3A, 0x3B}, {0x3F, 0x40}, {0x5B, 0x5D}, {0x5F, 0x5F}, {0x7B, 0x7B}, {0x7D, 0x7D}, {0xA1, 0xA1}, {0xA7, 0xA7}, {0xAB, 0xAB}, {0xB6, 0xB7}, {0xBB, 0xBB},
116
+ {0xBF, 0xBF}, {0x37E, 0x37E}, {0x387, 0x387}, {0x55A, 0x55F}, {0x589, 0x58A}, {0x5BE, 0x5BE}, {0x5C0, 0x5C0}, {0x5C3, 0x5C3}, {0x5C6, 0x5C6}, {0x5F3, 0x5F4}, {0x609, 0x60A}, {0x60C, 0x60D},
117
+ {0x61B, 0x61B}, {0x61E, 0x61F}, {0x66A, 0x66D}, {0x6D4, 0x6D4}, {0x700, 0x70D}, {0x7F7, 0x7F9}, {0x830, 0x83E}, {0x85E, 0x85E}, {0x964, 0x965}, {0x970, 0x970}, {0x9FD, 0x9FD}, {0xA76, 0xA76},
118
+ {0xAF0, 0xAF0}, {0xC77, 0xC77}, {0xC84, 0xC84}, {0xDF4, 0xDF4}, {0xE4F, 0xE4F}, {0xE5A, 0xE5B}, {0xF04, 0xF12}, {0xF14, 0xF14}, {0xF3A, 0xF3D}, {0xF85, 0xF85}, {0xFD0, 0xFD4}, {0xFD9, 0xFDA},
119
+ {0x104A, 0x104F}, {0x10FB, 0x10FB}, {0x1360, 0x1368}, {0x1400, 0x1400}, {0x166E, 0x166E}, {0x169B, 0x169C}, {0x16EB, 0x16ED}, {0x1735, 0x1736}, {0x17D4, 0x17D6}, {0x17D8, 0x17DA}, {0x1800, 0x180A},
120
+ {0x1944, 0x1945}, {0x1A1E, 0x1A1F}, {0x1AA0, 0x1AA6}, {0x1AA8, 0x1AAD}, {0x1B5A, 0x1B60}, {0x1BFC, 0x1BFF}, {0x1C3B, 0x1C3F}, {0x1C7E, 0x1C7F}, {0x1CC0, 0x1CC7}, {0x1CD3, 0x1CD3}, {0x2010, 0x2027},
121
+ {0x2030, 0x2043}, {0x2045, 0x2051}, {0x2053, 0x205E}, {0x207D, 0x207E}, {0x208D, 0x208E}, {0x2308, 0x230B}, {0x2329, 0x232A}, {0x2768, 0x2775}, {0x27C5, 0x27C6}, {0x27E6, 0x27EF}, {0x2983, 0x2998},
122
+ {0x29D8, 0x29DB}, {0x29FC, 0x29FD}, {0x2CF9, 0x2CFC}, {0x2CFE, 0x2CFF}, {0x2D70, 0x2D70}, {0x2E00, 0x2E2E}, {0x2E30, 0x2E4F}, {0x2E52, 0x2E52}, {0x3001, 0x3003}, {0x3008, 0x3011}, {0x3014, 0x301F},
123
+ {0x3030, 0x3030}, {0x303D, 0x303D}, {0x30A0, 0x30A0}, {0x30FB, 0x30FB}, {0xA4FE, 0xA4FF}, {0xA60D, 0xA60F}, {0xA673, 0xA673}, {0xA67E, 0xA67E}, {0xA6F2, 0xA6F7}, {0xA874, 0xA877}, {0xA8CE, 0xA8CF},
124
+ {0xA8F8, 0xA8FA}, {0xA8FC, 0xA8FC}, {0xA92E, 0xA92F}, {0xA95F, 0xA95F}, {0xA9C1, 0xA9CD}, {0xA9DE, 0xA9DF}, {0xAA5C, 0xAA5F}, {0xAADE, 0xAADF}, {0xAAF0, 0xAAF1}, {0xABEB, 0xABEB}, {0xFD3E, 0xFD3F},
125
+ {0xFE10, 0xFE19}, {0xFE30, 0xFE52}, {0xFE54, 0xFE61}, {0xFE63, 0xFE63}, {0xFE68, 0xFE68}, {0xFE6A, 0xFE6B}, {0xFF01, 0xFF03}, {0xFF05, 0xFF0A}, {0xFF0C, 0xFF0F}, {0xFF1A, 0xFF1B}, {0xFF1F, 0xFF20},
126
+ {0xFF3B, 0xFF3D}, {0xFF3F, 0xFF3F}, {0xFF5B, 0xFF5B}, {0xFF5D, 0xFF5D}, {0xFF5F, 0xFF65}, {0x10100, 0x10102}, {0x1039F, 0x1039F}, {0x103D0, 0x103D0}, {0x1056F, 0x1056F}, {0x10857, 0x10857},
127
+ {0x1091F, 0x1091F}, {0x1093F, 0x1093F}, {0x10A50, 0x10A58}, {0x10A7F, 0x10A7F}, {0x10AF0, 0x10AF6}, {0x10B39, 0x10B3F}, {0x10B99, 0x10B9C}, {0x10EAD, 0x10EAD}, {0x10F55, 0x10F59}, {0x11047, 0x1104D},
128
+ {0x110BB, 0x110BC}, {0x110BE, 0x110C1}, {0x11140, 0x11143}, {0x11174, 0x11175}, {0x111C5, 0x111C8}, {0x111CD, 0x111CD}, {0x111DB, 0x111DB}, {0x111DD, 0x111DF}, {0x11238, 0x1123D}, {0x112A9, 0x112A9},
129
+ {0x1144B, 0x1144F}, {0x1145A, 0x1145B}, {0x1145D, 0x1145D}, {0x114C6, 0x114C6}, {0x115C1, 0x115D7}, {0x11641, 0x11643}, {0x11660, 0x1166C}, {0x1173C, 0x1173E}, {0x1183B, 0x1183B}, {0x11944, 0x11946},
130
+ {0x119E2, 0x119E2}, {0x11A3F, 0x11A46}, {0x11A9A, 0x11A9C}, {0x11A9E, 0x11AA2}, {0x11C41, 0x11C45}, {0x11C70, 0x11C71}, {0x11EF7, 0x11EF8}, {0x11FFF, 0x11FFF}, {0x12470, 0x12474}, {0x16A6E, 0x16A6F},
131
+ {0x16AF5, 0x16AF5}, {0x16B37, 0x16B3B}, {0x16B44, 0x16B44}, {0x16E97, 0x16E9A}, {0x16FE2, 0x16FE2}, {0x1BC9F, 0x1BC9F}, {0x1DA87, 0x1DA8B}, {0x1E95E, 0x1E95F},
132
+ };
133
+
134
+ static const std::vector<std::pair<uint32_t, uint32_t>> symbol_ranges = {
135
+ {0x24, 0x24}, {0x2B, 0x2B}, {0x3C, 0x3E}, {0x5E, 0x5E}, {0x60, 0x60}, {0x7C, 0x7C}, {0x7E, 0x7E}, {0xA2, 0xA6}, {0xA8, 0xA9}, {0xAC, 0xAC}, {0xAE, 0xB1}, {0xB4, 0xB4}, {0xB8, 0xB8}, {0xD7, 0xD7},
136
+ {0xF7, 0xF7}, {0x2C2, 0x2C5}, {0x2D2, 0x2DF}, {0x2E5, 0x2EB}, {0x2ED, 0x2ED}, {0x2EF, 0x2FF}, {0x375, 0x375}, {0x384, 0x385}, {0x3F6, 0x3F6}, {0x482, 0x482}, {0x58D, 0x58F}, {0x606, 0x608},
137
+ {0x60B, 0x60B}, {0x60E, 0x60F}, {0x6DE, 0x6DE}, {0x6E9, 0x6E9}, {0x6FD, 0x6FE}, {0x7F6, 0x7F6}, {0x7FE, 0x7FF}, {0x9F2, 0x9F3}, {0x9FA, 0x9FB}, {0xAF1, 0xAF1}, {0xB70, 0xB70}, {0xBF3, 0xBFA},
138
+ {0xC7F, 0xC7F}, {0xD4F, 0xD4F}, {0xD79, 0xD79}, {0xE3F, 0xE3F}, {0xF01, 0xF03}, {0xF13, 0xF13}, {0xF15, 0xF17}, {0xF1A, 0xF1F}, {0xF34, 0xF34}, {0xF36, 0xF36}, {0xF38, 0xF38}, {0xFBE, 0xFC5},
139
+ {0xFC7, 0xFCC}, {0xFCE, 0xFCF}, {0xFD5, 0xFD8}, {0x109E, 0x109F}, {0x1390, 0x1399}, {0x166D, 0x166D}, {0x17DB, 0x17DB}, {0x1940, 0x1940}, {0x19DE, 0x19FF}, {0x1B61, 0x1B6A}, {0x1B74, 0x1B7C},
140
+ {0x1FBD, 0x1FBD}, {0x1FBF, 0x1FC1}, {0x1FCD, 0x1FCF}, {0x1FDD, 0x1FDF}, {0x1FED, 0x1FEF}, {0x1FFD, 0x1FFE}, {0x2044, 0x2044}, {0x2052, 0x2052}, {0x207A, 0x207C}, {0x208A, 0x208C}, {0x20A0, 0x20BF},
141
+ {0x2100, 0x2101}, {0x2103, 0x2106}, {0x2108, 0x2109}, {0x2114, 0x2114}, {0x2116, 0x2118}, {0x211E, 0x2123}, {0x2125, 0x2125}, {0x2127, 0x2127}, {0x2129, 0x2129}, {0x212E, 0x212E}, {0x213A, 0x213B},
142
+ {0x2140, 0x2144}, {0x214A, 0x214D}, {0x214F, 0x214F}, {0x218A, 0x218B}, {0x2190, 0x2307}, {0x230C, 0x2328}, {0x232B, 0x2426}, {0x2440, 0x244A}, {0x249C, 0x24E9}, {0x2500, 0x2767}, {0x2794, 0x27C4},
143
+ {0x27C7, 0x27E5}, {0x27F0, 0x2982}, {0x2999, 0x29D7}, {0x29DC, 0x29FB}, {0x29FE, 0x2B73}, {0x2B76, 0x2B95}, {0x2B97, 0x2BFF}, {0x2CE5, 0x2CEA}, {0x2E50, 0x2E51}, {0x2E80, 0x2E99}, {0x2E9B, 0x2EF3},
144
+ {0x2F00, 0x2FD5}, {0x2FF0, 0x2FFB}, {0x3004, 0x3004}, {0x3012, 0x3013}, {0x3020, 0x3020}, {0x3036, 0x3037}, {0x303E, 0x303F}, {0x309B, 0x309C}, {0x3190, 0x3191}, {0x3196, 0x319F}, {0x31C0, 0x31E3},
145
+ {0x3200, 0x321E}, {0x322A, 0x3247}, {0x3250, 0x3250}, {0x3260, 0x327F}, {0x328A, 0x32B0}, {0x32C0, 0x33FF}, {0x4DC0, 0x4DFF}, {0xA490, 0xA4C6}, {0xA700, 0xA716}, {0xA720, 0xA721}, {0xA789, 0xA78A},
146
+ {0xA828, 0xA82B}, {0xA836, 0xA839}, {0xAA77, 0xAA79}, {0xAB5B, 0xAB5B}, {0xAB6A, 0xAB6B}, {0xFB29, 0xFB29}, {0xFBB2, 0xFBC1}, {0xFDFC, 0xFDFD}, {0xFE62, 0xFE62}, {0xFE64, 0xFE66}, {0xFE69, 0xFE69},
147
+ {0xFF04, 0xFF04}, {0xFF0B, 0xFF0B}, {0xFF1C, 0xFF1E}, {0xFF3E, 0xFF3E}, {0xFF40, 0xFF40}, {0xFF5C, 0xFF5C}, {0xFF5E, 0xFF5E}, {0xFFE0, 0xFFE6}, {0xFFE8, 0xFFEE}, {0xFFFC, 0xFFFD}, {0x10137, 0x1013F},
148
+ {0x10179, 0x10189}, {0x1018C, 0x1018E}, {0x10190, 0x1019C}, {0x101A0, 0x101A0}, {0x101D0, 0x101FC}, {0x10877, 0x10878}, {0x10AC8, 0x10AC8}, {0x1173F, 0x1173F}, {0x11FD5, 0x11FF1}, {0x16B3C, 0x16B3F},
149
+ {0x16B45, 0x16B45}, {0x1BC9C, 0x1BC9C}, {0x1D000, 0x1D0F5}, {0x1D100, 0x1D126}, {0x1D129, 0x1D164}, {0x1D16A, 0x1D16C}, {0x1D183, 0x1D184}, {0x1D18C, 0x1D1A9}, {0x1D1AE, 0x1D1E8}, {0x1D200, 0x1D241},
150
+ {0x1D245, 0x1D245}, {0x1D300, 0x1D356}, {0x1D6C1, 0x1D6C1}, {0x1D6DB, 0x1D6DB}, {0x1D6FB, 0x1D6FB}, {0x1D715, 0x1D715}, {0x1D735, 0x1D735}, {0x1D74F, 0x1D74F}, {0x1D76F, 0x1D76F}, {0x1D789, 0x1D789},
151
+ {0x1D7A9, 0x1D7A9}, {0x1D7C3, 0x1D7C3}, {0x1D800, 0x1D9FF}, {0x1DA37, 0x1DA3A}, {0x1DA6D, 0x1DA74}, {0x1DA76, 0x1DA83}, {0x1DA85, 0x1DA86}, {0x1E14F, 0x1E14F}, {0x1E2FF, 0x1E2FF}, {0x1ECAC, 0x1ECAC},
152
+ {0x1ECB0, 0x1ECB0}, {0x1ED2E, 0x1ED2E}, {0x1EEF0, 0x1EEF1}, {0x1F000, 0x1F02B}, {0x1F030, 0x1F093}, {0x1F0A0, 0x1F0AE}, {0x1F0B1, 0x1F0BF}, {0x1F0C1, 0x1F0CF}, {0x1F0D1, 0x1F0F5}, {0x1F10D, 0x1F1AD},
153
+ {0x1F1E6, 0x1F202}, {0x1F210, 0x1F23B}, {0x1F240, 0x1F248}, {0x1F250, 0x1F251}, {0x1F260, 0x1F265}, {0x1F300, 0x1F6D7}, {0x1F6E0, 0x1F6EC}, {0x1F6F0, 0x1F6FC}, {0x1F700, 0x1F773}, {0x1F780, 0x1F7D8},
154
+ {0x1F7E0, 0x1F7EB}, {0x1F800, 0x1F80B}, {0x1F810, 0x1F847}, {0x1F850, 0x1F859}, {0x1F860, 0x1F887}, {0x1F890, 0x1F8AD}, {0x1F8B0, 0x1F8B1}, {0x1F900, 0x1F978}, {0x1F97A, 0x1F9CB}, {0x1F9CD, 0x1FA53},
155
+ {0x1FA60, 0x1FA6D}, {0x1FA70, 0x1FA74}, {0x1FA78, 0x1FA7A}, {0x1FA80, 0x1FA86}, {0x1FA90, 0x1FAA8}, {0x1FAB0, 0x1FAB6}, {0x1FAC0, 0x1FAC2}, {0x1FAD0, 0x1FAD6}, {0x1FB00, 0x1FB92}, {0x1FB94, 0x1FBCA},
156
+ };
157
+
158
+ static const std::vector<std::pair<uint32_t, uint32_t>> control_ranges = {
159
+ {0x0, 0x8}, {0xE, 0x1B}, {0x7F, 0x84}, {0x86, 0x9F}, {0xAD, 0xAD}, {0x378, 0x379}, {0x380, 0x383}, {0x38B, 0x38B}, {0x38D, 0x38D}, {0x3A2, 0x3A2}, {0x530, 0x530}, {0x557, 0x558}, {0x58B, 0x58C},
160
+ {0x590, 0x590}, {0x5C8, 0x5CF}, {0x5EB, 0x5EE}, {0x5F5, 0x605}, {0x61C, 0x61D}, {0x6DD, 0x6DD}, {0x70E, 0x70F}, {0x74B, 0x74C}, {0x7B2, 0x7BF}, {0x7FB, 0x7FC}, {0x82E, 0x82F}, {0x83F, 0x83F},
161
+ {0x85C, 0x85D}, {0x85F, 0x85F}, {0x86B, 0x89F}, {0x8B5, 0x8B5}, {0x8C8, 0x8D2}, {0x8E2, 0x8E2}, {0x984, 0x984}, {0x98D, 0x98E}, {0x991, 0x992}, {0x9A9, 0x9A9}, {0x9B1, 0x9B1}, {0x9B3, 0x9B5},
162
+ {0x9BA, 0x9BB}, {0x9C5, 0x9C6}, {0x9C9, 0x9CA}, {0x9CF, 0x9D6}, {0x9D8, 0x9DB}, {0x9DE, 0x9DE}, {0x9E4, 0x9E5}, {0x9FF, 0xA00}, {0xA04, 0xA04}, {0xA0B, 0xA0E}, {0xA11, 0xA12}, {0xA29, 0xA29},
163
+ {0xA31, 0xA31}, {0xA34, 0xA34}, {0xA37, 0xA37}, {0xA3A, 0xA3B}, {0xA3D, 0xA3D}, {0xA43, 0xA46}, {0xA49, 0xA4A}, {0xA4E, 0xA50}, {0xA52, 0xA58}, {0xA5D, 0xA5D}, {0xA5F, 0xA65}, {0xA77, 0xA80},
164
+ {0xA84, 0xA84}, {0xA8E, 0xA8E}, {0xA92, 0xA92}, {0xAA9, 0xAA9}, {0xAB1, 0xAB1}, {0xAB4, 0xAB4}, {0xABA, 0xABB}, {0xAC6, 0xAC6}, {0xACA, 0xACA}, {0xACE, 0xACF}, {0xAD1, 0xADF}, {0xAE4, 0xAE5},
165
+ {0xAF2, 0xAF8}, {0xB00, 0xB00}, {0xB04, 0xB04}, {0xB0D, 0xB0E}, {0xB11, 0xB12}, {0xB29, 0xB29}, {0xB31, 0xB31}, {0xB34, 0xB34}, {0xB3A, 0xB3B}, {0xB45, 0xB46}, {0xB49, 0xB4A}, {0xB4E, 0xB54},
166
+ {0xB58, 0xB5B}, {0xB5E, 0xB5E}, {0xB64, 0xB65}, {0xB78, 0xB81}, {0xB84, 0xB84}, {0xB8B, 0xB8D}, {0xB91, 0xB91}, {0xB96, 0xB98}, {0xB9B, 0xB9B}, {0xB9D, 0xB9D}, {0xBA0, 0xBA2}, {0xBA5, 0xBA7},
167
+ {0xBAB, 0xBAD}, {0xBBA, 0xBBD}, {0xBC3, 0xBC5}, {0xBC9, 0xBC9}, {0xBCE, 0xBCF}, {0xBD1, 0xBD6}, {0xBD8, 0xBE5}, {0xBFB, 0xBFF}, {0xC0D, 0xC0D}, {0xC11, 0xC11}, {0xC29, 0xC29}, {0xC3A, 0xC3C},
168
+ {0xC45, 0xC45}, {0xC49, 0xC49}, {0xC4E, 0xC54}, {0xC57, 0xC57}, {0xC5B, 0xC5F}, {0xC64, 0xC65}, {0xC70, 0xC76}, {0xC8D, 0xC8D}, {0xC91, 0xC91}, {0xCA9, 0xCA9}, {0xCB4, 0xCB4}, {0xCBA, 0xCBB},
169
+ {0xCC5, 0xCC5}, {0xCC9, 0xCC9}, {0xCCE, 0xCD4}, {0xCD7, 0xCDD}, {0xCDF, 0xCDF}, {0xCE4, 0xCE5}, {0xCF0, 0xCF0}, {0xCF3, 0xCFF}, {0xD0D, 0xD0D}, {0xD11, 0xD11}, {0xD45, 0xD45}, {0xD49, 0xD49},
170
+ {0xD50, 0xD53}, {0xD64, 0xD65}, {0xD80, 0xD80}, {0xD84, 0xD84}, {0xD97, 0xD99}, {0xDB2, 0xDB2}, {0xDBC, 0xDBC}, {0xDBE, 0xDBF}, {0xDC7, 0xDC9}, {0xDCB, 0xDCE}, {0xDD5, 0xDD5}, {0xDD7, 0xDD7},
171
+ {0xDE0, 0xDE5}, {0xDF0, 0xDF1}, {0xDF5, 0xE00}, {0xE3B, 0xE3E}, {0xE5C, 0xE80}, {0xE83, 0xE83}, {0xE85, 0xE85}, {0xE8B, 0xE8B}, {0xEA4, 0xEA4}, {0xEA6, 0xEA6}, {0xEBE, 0xEBF}, {0xEC5, 0xEC5},
172
+ {0xEC7, 0xEC7}, {0xECE, 0xECF}, {0xEDA, 0xEDB}, {0xEE0, 0xEFF}, {0xF48, 0xF48}, {0xF6D, 0xF70}, {0xF98, 0xF98}, {0xFBD, 0xFBD}, {0xFCD, 0xFCD}, {0xFDB, 0xFFF}, {0x10C6, 0x10C6}, {0x10C8, 0x10CC},
173
+ {0x10CE, 0x10CF}, {0x1249, 0x1249}, {0x124E, 0x124F}, {0x1257, 0x1257}, {0x1259, 0x1259}, {0x125E, 0x125F}, {0x1289, 0x1289}, {0x128E, 0x128F}, {0x12B1, 0x12B1}, {0x12B6, 0x12B7}, {0x12BF, 0x12BF},
174
+ {0x12C1, 0x12C1}, {0x12C6, 0x12C7}, {0x12D7, 0x12D7}, {0x1311, 0x1311}, {0x1316, 0x1317}, {0x135B, 0x135C}, {0x137D, 0x137F}, {0x139A, 0x139F}, {0x13F6, 0x13F7}, {0x13FE, 0x13FF}, {0x169D, 0x169F},
175
+ {0x16F9, 0x16FF}, {0x170D, 0x170D}, {0x1715, 0x171F}, {0x1737, 0x173F}, {0x1754, 0x175F}, {0x176D, 0x176D}, {0x1771, 0x1771}, {0x1774, 0x177F}, {0x17DE, 0x17DF}, {0x17EA, 0x17EF}, {0x17FA, 0x17FF},
176
+ {0x180E, 0x180F}, {0x181A, 0x181F}, {0x1879, 0x187F}, {0x18AB, 0x18AF}, {0x18F6, 0x18FF}, {0x191F, 0x191F}, {0x192C, 0x192F}, {0x193C, 0x193F}, {0x1941, 0x1943}, {0x196E, 0x196F}, {0x1975, 0x197F},
177
+ {0x19AC, 0x19AF}, {0x19CA, 0x19CF}, {0x19DB, 0x19DD}, {0x1A1C, 0x1A1D}, {0x1A5F, 0x1A5F}, {0x1A7D, 0x1A7E}, {0x1A8A, 0x1A8F}, {0x1A9A, 0x1A9F}, {0x1AAE, 0x1AAF}, {0x1AC1, 0x1AFF}, {0x1B4C, 0x1B4F},
178
+ {0x1B7D, 0x1B7F}, {0x1BF4, 0x1BFB}, {0x1C38, 0x1C3A}, {0x1C4A, 0x1C4C}, {0x1C89, 0x1C8F}, {0x1CBB, 0x1CBC}, {0x1CC8, 0x1CCF}, {0x1CFB, 0x1CFF}, {0x1DFA, 0x1DFA}, {0x1F16, 0x1F17}, {0x1F1E, 0x1F1F},
179
+ {0x1F46, 0x1F47}, {0x1F4E, 0x1F4F}, {0x1F58, 0x1F58}, {0x1F5A, 0x1F5A}, {0x1F5C, 0x1F5C}, {0x1F5E, 0x1F5E}, {0x1F7E, 0x1F7F}, {0x1FB5, 0x1FB5}, {0x1FC5, 0x1FC5}, {0x1FD4, 0x1FD5}, {0x1FDC, 0x1FDC},
180
+ {0x1FF0, 0x1FF1}, {0x1FF5, 0x1FF5}, {0x1FFF, 0x1FFF}, {0x200B, 0x200F}, {0x202A, 0x202E}, {0x2060, 0x206F}, {0x2072, 0x2073}, {0x208F, 0x208F}, {0x209D, 0x209F}, {0x20C0, 0x20CF}, {0x20F1, 0x20FF},
181
+ {0x218C, 0x218F}, {0x2427, 0x243F}, {0x244B, 0x245F}, {0x2B74, 0x2B75}, {0x2B96, 0x2B96}, {0x2C2F, 0x2C2F}, {0x2C5F, 0x2C5F}, {0x2CF4, 0x2CF8}, {0x2D26, 0x2D26}, {0x2D28, 0x2D2C}, {0x2D2E, 0x2D2F},
182
+ {0x2D68, 0x2D6E}, {0x2D71, 0x2D7E}, {0x2D97, 0x2D9F}, {0x2DA7, 0x2DA7}, {0x2DAF, 0x2DAF}, {0x2DB7, 0x2DB7}, {0x2DBF, 0x2DBF}, {0x2DC7, 0x2DC7}, {0x2DCF, 0x2DCF}, {0x2DD7, 0x2DD7}, {0x2DDF, 0x2DDF},
183
+ {0x2E53, 0x2E7F}, {0x2E9A, 0x2E9A}, {0x2EF4, 0x2EFF}, {0x2FD6, 0x2FEF}, {0x2FFC, 0x2FFF}, {0x3040, 0x3040}, {0x3097, 0x3098}, {0x3100, 0x3104}, {0x3130, 0x3130}, {0x318F, 0x318F}, {0x31E4, 0x31EF},
184
+ {0x321F, 0x321F}, {0x9FFD, 0x9FFF}, {0xA48D, 0xA48F}, {0xA4C7, 0xA4CF}, {0xA62C, 0xA63F}, {0xA6F8, 0xA6FF}, {0xA7C0, 0xA7C1}, {0xA7CB, 0xA7F4}, {0xA82D, 0xA82F}, {0xA83A, 0xA83F}, {0xA878, 0xA87F},
185
+ {0xA8C6, 0xA8CD}, {0xA8DA, 0xA8DF}, {0xA954, 0xA95E}, {0xA97D, 0xA97F}, {0xA9CE, 0xA9CE}, {0xA9DA, 0xA9DD}, {0xA9FF, 0xA9FF}, {0xAA37, 0xAA3F}, {0xAA4E, 0xAA4F}, {0xAA5A, 0xAA5B}, {0xAAC3, 0xAADA},
186
+ {0xAAF7, 0xAB00}, {0xAB07, 0xAB08}, {0xAB0F, 0xAB10}, {0xAB17, 0xAB1F}, {0xAB27, 0xAB27}, {0xAB2F, 0xAB2F}, {0xAB6C, 0xAB6F}, {0xABEE, 0xABEF}, {0xABFA, 0xABFF}, {0xD7A4, 0xD7AF}, {0xD7C7, 0xD7CA},
187
+ {0xD7FC, 0xF8FF}, {0xFA6E, 0xFA6F}, {0xFADA, 0xFAFF}, {0xFB07, 0xFB12}, {0xFB18, 0xFB1C}, {0xFB37, 0xFB37}, {0xFB3D, 0xFB3D}, {0xFB3F, 0xFB3F}, {0xFB42, 0xFB42}, {0xFB45, 0xFB45}, {0xFBC2, 0xFBD2},
188
+ {0xFD40, 0xFD4F}, {0xFD90, 0xFD91}, {0xFDC8, 0xFDEF}, {0xFDFE, 0xFDFF}, {0xFE1A, 0xFE1F}, {0xFE53, 0xFE53}, {0xFE67, 0xFE67}, {0xFE6C, 0xFE6F}, {0xFE75, 0xFE75}, {0xFEFD, 0xFF00}, {0xFFBF, 0xFFC1},
189
+ {0xFFC8, 0xFFC9}, {0xFFD0, 0xFFD1}, {0xFFD8, 0xFFD9}, {0xFFDD, 0xFFDF}, {0xFFE7, 0xFFE7}, {0xFFEF, 0xFFFB}, {0xFFFE, 0xFFFF}, {0x1000C, 0x1000C}, {0x10027, 0x10027}, {0x1003B, 0x1003B},
190
+ {0x1003E, 0x1003E}, {0x1004E, 0x1004F}, {0x1005E, 0x1007F}, {0x100FB, 0x100FF}, {0x10103, 0x10106}, {0x10134, 0x10136}, {0x1018F, 0x1018F}, {0x1019D, 0x1019F}, {0x101A1, 0x101CF}, {0x101FE, 0x1027F},
191
+ {0x1029D, 0x1029F}, {0x102D1, 0x102DF}, {0x102FC, 0x102FF}, {0x10324, 0x1032C}, {0x1034B, 0x1034F}, {0x1037B, 0x1037F}, {0x1039E, 0x1039E}, {0x103C4, 0x103C7}, {0x103D6, 0x103FF}, {0x1049E, 0x1049F},
192
+ {0x104AA, 0x104AF}, {0x104D4, 0x104D7}, {0x104FC, 0x104FF}, {0x10528, 0x1052F}, {0x10564, 0x1056E}, {0x10570, 0x105FF}, {0x10737, 0x1073F}, {0x10756, 0x1075F}, {0x10768, 0x107FF}, {0x10806, 0x10807},
193
+ {0x10809, 0x10809}, {0x10836, 0x10836}, {0x10839, 0x1083B}, {0x1083D, 0x1083E}, {0x10856, 0x10856}, {0x1089F, 0x108A6}, {0x108B0, 0x108DF}, {0x108F3, 0x108F3}, {0x108F6, 0x108FA}, {0x1091C, 0x1091E},
194
+ {0x1093A, 0x1093E}, {0x10940, 0x1097F}, {0x109B8, 0x109BB}, {0x109D0, 0x109D1}, {0x10A04, 0x10A04}, {0x10A07, 0x10A0B}, {0x10A14, 0x10A14}, {0x10A18, 0x10A18}, {0x10A36, 0x10A37}, {0x10A3B, 0x10A3E},
195
+ {0x10A49, 0x10A4F}, {0x10A59, 0x10A5F}, {0x10AA0, 0x10ABF}, {0x10AE7, 0x10AEA}, {0x10AF7, 0x10AFF}, {0x10B36, 0x10B38}, {0x10B56, 0x10B57}, {0x10B73, 0x10B77}, {0x10B92, 0x10B98}, {0x10B9D, 0x10BA8},
196
+ {0x10BB0, 0x10BFF}, {0x10C49, 0x10C7F}, {0x10CB3, 0x10CBF}, {0x10CF3, 0x10CF9}, {0x10D28, 0x10D2F}, {0x10D3A, 0x10E5F}, {0x10E7F, 0x10E7F}, {0x10EAA, 0x10EAA}, {0x10EAE, 0x10EAF}, {0x10EB2, 0x10EFF},
197
+ {0x10F28, 0x10F2F}, {0x10F5A, 0x10FAF}, {0x10FCC, 0x10FDF}, {0x10FF7, 0x10FFF}, {0x1104E, 0x11051}, {0x11070, 0x1107E}, {0x110BD, 0x110BD}, {0x110C2, 0x110CF}, {0x110E9, 0x110EF}, {0x110FA, 0x110FF},
198
+ {0x11135, 0x11135}, {0x11148, 0x1114F}, {0x11177, 0x1117F}, {0x111E0, 0x111E0}, {0x111F5, 0x111FF}, {0x11212, 0x11212}, {0x1123F, 0x1127F}, {0x11287, 0x11287}, {0x11289, 0x11289}, {0x1128E, 0x1128E},
199
+ {0x1129E, 0x1129E}, {0x112AA, 0x112AF}, {0x112EB, 0x112EF}, {0x112FA, 0x112FF}, {0x11304, 0x11304}, {0x1130D, 0x1130E}, {0x11311, 0x11312}, {0x11329, 0x11329}, {0x11331, 0x11331}, {0x11334, 0x11334},
200
+ {0x1133A, 0x1133A}, {0x11345, 0x11346}, {0x11349, 0x1134A}, {0x1134E, 0x1134F}, {0x11351, 0x11356}, {0x11358, 0x1135C}, {0x11364, 0x11365}, {0x1136D, 0x1136F}, {0x11375, 0x113FF}, {0x1145C, 0x1145C},
201
+ {0x11462, 0x1147F}, {0x114C8, 0x114CF}, {0x114DA, 0x1157F}, {0x115B6, 0x115B7}, {0x115DE, 0x115FF}, {0x11645, 0x1164F}, {0x1165A, 0x1165F}, {0x1166D, 0x1167F}, {0x116B9, 0x116BF}, {0x116CA, 0x116FF},
202
+ {0x1171B, 0x1171C}, {0x1172C, 0x1172F}, {0x11740, 0x117FF}, {0x1183C, 0x1189F}, {0x118F3, 0x118FE}, {0x11907, 0x11908}, {0x1190A, 0x1190B}, {0x11914, 0x11914}, {0x11917, 0x11917}, {0x11936, 0x11936},
203
+ {0x11939, 0x1193A}, {0x11947, 0x1194F}, {0x1195A, 0x1199F}, {0x119A8, 0x119A9}, {0x119D8, 0x119D9}, {0x119E5, 0x119FF}, {0x11A48, 0x11A4F}, {0x11AA3, 0x11ABF}, {0x11AF9, 0x11BFF}, {0x11C09, 0x11C09},
204
+ {0x11C37, 0x11C37}, {0x11C46, 0x11C4F}, {0x11C6D, 0x11C6F}, {0x11C90, 0x11C91}, {0x11CA8, 0x11CA8}, {0x11CB7, 0x11CFF}, {0x11D07, 0x11D07}, {0x11D0A, 0x11D0A}, {0x11D37, 0x11D39}, {0x11D3B, 0x11D3B},
205
+ {0x11D3E, 0x11D3E}, {0x11D48, 0x11D4F}, {0x11D5A, 0x11D5F}, {0x11D66, 0x11D66}, {0x11D69, 0x11D69}, {0x11D8F, 0x11D8F}, {0x11D92, 0x11D92}, {0x11D99, 0x11D9F}, {0x11DAA, 0x11EDF}, {0x11EF9, 0x11FAF},
206
+ {0x11FB1, 0x11FBF}, {0x11FF2, 0x11FFE}, {0x1239A, 0x123FF}, {0x1246F, 0x1246F}, {0x12475, 0x1247F}, {0x12544, 0x12FFF}, {0x1342F, 0x143FF}, {0x14647, 0x167FF}, {0x16A39, 0x16A3F}, {0x16A5F, 0x16A5F},
207
+ {0x16A6A, 0x16A6D}, {0x16A70, 0x16ACF}, {0x16AEE, 0x16AEF}, {0x16AF6, 0x16AFF}, {0x16B46, 0x16B4F}, {0x16B5A, 0x16B5A}, {0x16B62, 0x16B62}, {0x16B78, 0x16B7C}, {0x16B90, 0x16E3F}, {0x16E9B, 0x16EFF},
208
+ {0x16F4B, 0x16F4E}, {0x16F88, 0x16F8E}, {0x16FA0, 0x16FDF}, {0x16FE5, 0x16FEF}, {0x16FF2, 0x16FFF}, {0x187F8, 0x187FF}, {0x18CD6, 0x18CFF}, {0x18D09, 0x1AFFF}, {0x1B11F, 0x1B14F}, {0x1B153, 0x1B163},
209
+ {0x1B168, 0x1B16F}, {0x1B2FC, 0x1BBFF}, {0x1BC6B, 0x1BC6F}, {0x1BC7D, 0x1BC7F}, {0x1BC89, 0x1BC8F}, {0x1BC9A, 0x1BC9B}, {0x1BCA0, 0x1CFFF}, {0x1D0F6, 0x1D0FF}, {0x1D127, 0x1D128}, {0x1D173, 0x1D17A},
210
+ {0x1D1E9, 0x1D1FF}, {0x1D246, 0x1D2DF}, {0x1D2F4, 0x1D2FF}, {0x1D357, 0x1D35F}, {0x1D379, 0x1D3FF}, {0x1D455, 0x1D455}, {0x1D49D, 0x1D49D}, {0x1D4A0, 0x1D4A1}, {0x1D4A3, 0x1D4A4}, {0x1D4A7, 0x1D4A8},
211
+ {0x1D4AD, 0x1D4AD}, {0x1D4BA, 0x1D4BA}, {0x1D4BC, 0x1D4BC}, {0x1D4C4, 0x1D4C4}, {0x1D506, 0x1D506}, {0x1D50B, 0x1D50C}, {0x1D515, 0x1D515}, {0x1D51D, 0x1D51D}, {0x1D53A, 0x1D53A}, {0x1D53F, 0x1D53F},
212
+ {0x1D545, 0x1D545}, {0x1D547, 0x1D549}, {0x1D551, 0x1D551}, {0x1D6A6, 0x1D6A7}, {0x1D7CC, 0x1D7CD}, {0x1DA8C, 0x1DA9A}, {0x1DAA0, 0x1DAA0}, {0x1DAB0, 0x1DFFF}, {0x1E007, 0x1E007}, {0x1E019, 0x1E01A},
213
+ {0x1E022, 0x1E022}, {0x1E025, 0x1E025}, {0x1E02B, 0x1E0FF}, {0x1E12D, 0x1E12F}, {0x1E13E, 0x1E13F}, {0x1E14A, 0x1E14D}, {0x1E150, 0x1E2BF}, {0x1E2FA, 0x1E2FE}, {0x1E300, 0x1E7FF}, {0x1E8C5, 0x1E8C6},
214
+ {0x1E8D7, 0x1E8FF}, {0x1E94C, 0x1E94F}, {0x1E95A, 0x1E95D}, {0x1E960, 0x1EC70}, {0x1ECB5, 0x1ED00}, {0x1ED3E, 0x1EDFF}, {0x1EE04, 0x1EE04}, {0x1EE20, 0x1EE20}, {0x1EE23, 0x1EE23}, {0x1EE25, 0x1EE26},
215
+ {0x1EE28, 0x1EE28}, {0x1EE33, 0x1EE33}, {0x1EE38, 0x1EE38}, {0x1EE3A, 0x1EE3A}, {0x1EE3C, 0x1EE41}, {0x1EE43, 0x1EE46}, {0x1EE48, 0x1EE48}, {0x1EE4A, 0x1EE4A}, {0x1EE4C, 0x1EE4C}, {0x1EE50, 0x1EE50},
216
+ {0x1EE53, 0x1EE53}, {0x1EE55, 0x1EE56}, {0x1EE58, 0x1EE58}, {0x1EE5A, 0x1EE5A}, {0x1EE5C, 0x1EE5C}, {0x1EE5E, 0x1EE5E}, {0x1EE60, 0x1EE60}, {0x1EE63, 0x1EE63}, {0x1EE65, 0x1EE66}, {0x1EE6B, 0x1EE6B},
217
+ {0x1EE73, 0x1EE73}, {0x1EE78, 0x1EE78}, {0x1EE7D, 0x1EE7D}, {0x1EE7F, 0x1EE7F}, {0x1EE8A, 0x1EE8A}, {0x1EE9C, 0x1EEA0}, {0x1EEA4, 0x1EEA4}, {0x1EEAA, 0x1EEAA}, {0x1EEBC, 0x1EEEF}, {0x1EEF2, 0x1EFFF},
218
+ {0x1F02C, 0x1F02F}, {0x1F094, 0x1F09F}, {0x1F0AF, 0x1F0B0}, {0x1F0C0, 0x1F0C0}, {0x1F0D0, 0x1F0D0}, {0x1F0F6, 0x1F0FF}, {0x1F1AE, 0x1F1E5}, {0x1F203, 0x1F20F}, {0x1F23C, 0x1F23F}, {0x1F249, 0x1F24F},
219
+ {0x1F252, 0x1F25F}, {0x1F266, 0x1F2FF}, {0x1F6D8, 0x1F6DF}, {0x1F6ED, 0x1F6EF}, {0x1F6FD, 0x1F6FF}, {0x1F774, 0x1F77F}, {0x1F7D9, 0x1F7DF}, {0x1F7EC, 0x1F7FF}, {0x1F80C, 0x1F80F}, {0x1F848, 0x1F84F},
220
+ {0x1F85A, 0x1F85F}, {0x1F888, 0x1F88F}, {0x1F8AE, 0x1F8AF}, {0x1F8B2, 0x1F8FF}, {0x1F979, 0x1F979}, {0x1F9CC, 0x1F9CC}, {0x1FA54, 0x1FA5F}, {0x1FA6E, 0x1FA6F}, {0x1FA75, 0x1FA77}, {0x1FA7B, 0x1FA7F},
221
+ {0x1FA87, 0x1FA8F}, {0x1FAA9, 0x1FAAF}, {0x1FAB7, 0x1FABF}, {0x1FAC3, 0x1FACF}, {0x1FAD7, 0x1FAFF}, {0x1FB93, 0x1FB93}, {0x1FBCB, 0x1FBEF}, {0x1FBFA, 0x1FFFF}, {0x2A6DE, 0x2A6FF}, {0x2B735, 0x2B73F},
222
+ {0x2B81E, 0x2B81F}, {0x2CEA2, 0x2CEAF}, {0x2EBE1, 0x2F7FF}, {0x2FA1E, 0x2FFFF}, {0x3134B, 0xE00FF}, {0xE01F0, 0x10FFFF},
223
+ };
224
+
225
+ static std::string codepoint_to_utf8(uint32_t cp) {
226
+ std::string result;
227
+ if (/* 0x00 <= cp && */ cp <= 0x7f) {
228
+ result.push_back(cp);
229
+ }
230
+ else if (0x80 <= cp && cp <= 0x7ff) {
231
+ result.push_back(0xc0 | ((cp >> 6) & 0x1f));
232
+ result.push_back(0x80 | (cp & 0x3f));
233
+ }
234
+ else if (0x800 <= cp && cp <= 0xffff) {
235
+ result.push_back(0xe0 | ((cp >> 12) & 0x0f));
236
+ result.push_back(0x80 | ((cp >> 6) & 0x3f));
237
+ result.push_back(0x80 | (cp & 0x3f));
238
+ }
239
+ else if (0x10000 <= cp && cp <= 0x10ffff) {
240
+ result.push_back(0xf0 | ((cp >> 18) & 0x07));
241
+ result.push_back(0x80 | ((cp >> 12) & 0x3f));
242
+ result.push_back(0x80 | ((cp >> 6) & 0x3f));
243
+ result.push_back(0x80 | (cp & 0x3f));
244
+ }
245
+ else {
246
+ throw std::invalid_argument("invalid codepoint");
247
+ }
248
+ return result;
249
+ }
250
+
251
+ static std::string codepoints_to_utf8(const std::vector<uint32_t> & cps) {
252
+ std::string result;
253
+ for (size_t i = 0; i < cps.size(); ++i) {
254
+ result.append(codepoint_to_utf8(cps[i]));
255
+ }
256
+ return result;
257
+ }
258
+
259
+ static uint32_t codepoint_from_utf8(const std::string & utf8, size_t & offset) {
260
+ assert(offset < utf8.size());
261
+ if (!(utf8[offset + 0] & 0x80)) {
262
+ auto result = utf8[offset + 0];
263
+ offset += 1;
264
+ return result;
265
+ }
266
+ else if (!(utf8[offset + 0] & 0x40)) {
267
+ throw std::invalid_argument("invalid character");
268
+ }
269
+ else if (!(utf8[offset + 0] & 0x20)) {
270
+ if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80))
271
+ throw std::invalid_argument("invalid character");
272
+ auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f);
273
+ offset += 2;
274
+ return result;
275
+ }
276
+ else if (!(utf8[offset + 0] & 0x10)) {
277
+ if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80))
278
+ throw std::invalid_argument("invalid character");
279
+ auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f);
280
+ offset += 3;
281
+ return result;
282
+ }
283
+ else if (!(utf8[offset + 0] & 0x08)) {
284
+ if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80))
285
+ throw std::invalid_argument("invalid character");
286
+ auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f);
287
+ offset += 4;
288
+ return result;
289
+ }
290
+ throw std::invalid_argument("invalid string");
291
+ }
292
+
293
+ static std::vector<uint32_t> codepoints_from_utf8(const std::string & utf8) {
294
+ std::vector<uint32_t> result;
295
+ size_t offset = 0;
296
+ while (offset < utf8.size()) {
297
+ result.push_back(codepoint_from_utf8(utf8, offset));
298
+ }
299
+ return result;
300
+ }
301
+
302
+ static std::vector<uint16_t> codepoint_to_utf16(uint32_t cp) {
303
+ std::vector<uint16_t> result;
304
+ if (/* 0x0000 <= cp && */ cp <= 0xffff) {
305
+ result.emplace_back(cp);
306
+ }
307
+ else if (0x10000 <= cp && cp <= 0x10ffff) {
308
+ result.emplace_back(0xd800 | ((cp - 0x10000) >> 10));
309
+ result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff));
310
+ }
311
+ else {
312
+ throw std::invalid_argument("invalid codepoint");
313
+ }
314
+ return result;
315
+ }
316
+
317
+ static std::vector<uint16_t> codepoints_to_utf16(const std::vector<uint32_t> & cps) {
318
+ std::vector<uint16_t> result;
319
+ for (size_t i = 0; i < cps.size(); ++i) {
320
+ auto temp = codepoint_to_utf16(cps[i]);
321
+ result.insert(result.end(), temp.begin(), temp.end());
322
+ }
323
+ return result;
324
+ }
325
+
326
+ static uint32_t codepoint_from_utf16(const std::vector<uint16_t> & utf16, size_t & offset) {
327
+ assert(offset < utf16.size());
328
+ if (((utf16[offset + 0] >> 10) << 10) != 0xd800) {
329
+ auto result = utf16[offset + 0];
330
+ offset += 1;
331
+ return result;
332
+ }
333
+ else {
334
+ if (offset + 1 >= utf16.size() || !((utf16[offset + 1] & 0xdc00) == 0xdc00))
335
+ throw std::invalid_argument("invalid character");
336
+ auto result = 0x10000 + (((utf16[offset + 0] & 0x03ff) << 10) | (utf16[offset + 1] & 0x03ff));
337
+ offset += 2;
338
+ return result;
339
+ }
340
+ throw std::invalid_argument("invalid string");
341
+ }
342
+
343
+ static std::vector<uint32_t> codepoints_from_utf16(const std::vector<uint16_t> & utf16) {
344
+ std::vector<uint32_t> result;
345
+ size_t offset = 0;
346
+ while (offset < utf16.size())
347
+ result.push_back(codepoint_from_utf16(utf16, offset));
348
+ return result;
349
+ }
350
+
351
+ #define CODEPOINT_TYPE_UNIDENTIFIED 0
352
+ #define CODEPOINT_TYPE_DIGIT 1
353
+ #define CODEPOINT_TYPE_LETTER 2
354
+ #define CODEPOINT_TYPE_WHITESPACE 3
355
+ #define CODEPOINT_TYPE_ACCENT_MARK 4
356
+ #define CODEPOINT_TYPE_PUNCTUATION 5
357
+ #define CODEPOINT_TYPE_SYMBOL 6
358
+ #define CODEPOINT_TYPE_CONTROL 7
359
+
360
+ static std::unordered_map<uint32_t, int> codepoint_type_map() {
361
+ std::unordered_map<uint32_t, int> codepoint_types;
362
+ for (auto p : digit_ranges) {
363
+ for(auto i = p.first; i <= p.second; ++ i)
364
+ codepoint_types[i] = CODEPOINT_TYPE_DIGIT;
365
+ }
366
+ for(auto p : letter_ranges) {
367
+ for(auto i = p.first; i <= p.second; ++ i)
368
+ codepoint_types[i] = CODEPOINT_TYPE_LETTER;
369
+ }
370
+ for(auto p : whitespace_ranges) {
371
+ for(auto i = p.first; i <= p.second; ++ i)
372
+ codepoint_types[i] = CODEPOINT_TYPE_WHITESPACE;
373
+ }
374
+ for(auto p : accent_mark_ranges) {
375
+ for(auto i = p.first; i <= p.second; ++ i)
376
+ codepoint_types[i] = CODEPOINT_TYPE_ACCENT_MARK;
377
+ }
378
+ for(auto p : punctuation_ranges) {
379
+ for(auto i = p.first; i <= p.second; ++ i)
380
+ codepoint_types[i] = CODEPOINT_TYPE_PUNCTUATION;
381
+ }
382
+ for (auto p : symbol_ranges) {
383
+ for (auto i = p.first; i <= p.second; ++i)
384
+ codepoint_types[i] = CODEPOINT_TYPE_SYMBOL;
385
+ }
386
+ for(auto p : control_ranges) {
387
+ for(auto i = p.first; i <= p.second; ++ i)
388
+ codepoint_types[i] = CODEPOINT_TYPE_CONTROL;
389
+ }
390
+ return codepoint_types;
391
+ }
392
+
393
+ static int codepoint_type(uint32_t cp) {
394
+ static std::unordered_map<uint32_t, int> codepoint_types = codepoint_type_map();
395
+ return codepoint_types[cp];
396
+ }
397
+
398
+ static int codepoint_type(const std::string & utf8) {
399
+ if (utf8.length() == 0)
400
+ return CODEPOINT_TYPE_UNIDENTIFIED;
401
+ size_t offset = 0;
402
+ return codepoint_type(codepoint_from_utf8(utf8, offset));
403
+ }
404
+
405
+ static std::unordered_map<uint8_t, std::string> bytes_to_unicode_map_bpe() {
406
+ std::unordered_map<uint8_t, std::string> map;
407
+ for (int ch = u'!'; ch <= u'~'; ++ch) {
408
+ assert(0 <= ch && ch < 256);
409
+ map[ch] = codepoint_to_utf8(ch);
410
+ }
411
+ for (int ch = u'¡'; ch <= u'¬'; ++ch) {
412
+ assert(0 <= ch && ch < 256);
413
+ map[ch] = codepoint_to_utf8(ch);
414
+ }
415
+ for (int ch = u'®'; ch <= u'ÿ'; ++ch) {
416
+ assert(0 <= ch && ch < 256);
417
+ map[ch] = codepoint_to_utf8(ch);
418
+ }
419
+ auto n = 0;
420
+ for (int ch = 0; ch < 256; ++ch) {
421
+ if (map.find(ch) == map.end()) {
422
+ map[ch] = codepoint_to_utf8(256 + n);
423
+ ++n;
424
+ }
425
+ }
426
+ return map;
427
+ }
428
+
429
+ static std::string bytes_to_unicode_bpe(uint8_t byte) {
430
+ static std::unordered_map<uint8_t, std::string> map = bytes_to_unicode_map_bpe();
431
+ return map.at(byte);
432
+ }
433
+
434
+ static std::unordered_map<std::string, uint8_t> unicode_to_bytes_map_bpe() {
435
+ std::unordered_map<std::string, uint8_t> map;
436
+ for (int ch = u'!'; ch <= u'~'; ++ch) {
437
+ assert(0 <= ch && ch < 256);
438
+ map[codepoint_to_utf8(ch)] = ch;
439
+ }
440
+ for (int ch = u'¡'; ch <= u'¬'; ++ch) {
441
+ assert(0 <= ch && ch < 256);
442
+ map[codepoint_to_utf8(ch)] = ch;
443
+ }
444
+ for (int ch = u'®'; ch <= u'ÿ'; ++ch) {
445
+ assert(0 <= ch && ch < 256);
446
+ map[codepoint_to_utf8(ch)] = ch;
447
+ }
448
+ auto n = 0;
449
+ for (int ch = 0; ch < 256; ++ch) {
450
+ if (map.find(codepoint_to_utf8(ch)) == map.end()) {
451
+ map[codepoint_to_utf8(256 + n)] = ch;
452
+ ++n;
453
+ }
454
+ }
455
+ return map;
456
+ }
457
+
458
+ static uint8_t unicode_to_bytes_bpe(const std::string & utf8) {
459
+ static std::unordered_map<std::string, uint8_t> map = unicode_to_bytes_map_bpe();
460
+ return map.at(utf8);
461
+ }
462
+
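The hunk above adds a self-contained set of Unicode helpers: codepoint range tables, UTF-8 and UTF-16 encode/decode, codepoint classification, and byte<->codepoint maps of the kind used by GPT-2-style byte-level BPE. A minimal sketch of how the pieces compose, assuming the caller lives in the same translation unit (the helpers are declared static) and that <cassert> and <cstdio> are included:

    // decode a UTF-8 string into codepoints, classify each one, then re-encode it
    std::string s = "abc 123";
    std::vector<uint32_t> cps = codepoints_from_utf8(s);
    for (uint32_t cp : cps) {
        // codepoint_type() returns one of the CODEPOINT_TYPE_* values defined above
        printf("U+%04X type=%d\n", (unsigned) cp, codepoint_type(cp));
    }
    std::string back = codepoints_to_utf8(cps);
    assert(back == s); // the round trip is lossless for valid UTF-8 input

codepoints_to_utf16()/codepoints_from_utf16() follow the same pattern using surrogate pairs, and bytes_to_unicode_bpe()/unicode_to_bytes_bpe() map each raw byte to a printable codepoint (as a UTF-8 string) and back.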
examples/whisper.android/app/src/main/jni/whisper/CMakeLists.txt CHANGED
@@ -9,6 +9,8 @@ set(
9
  SOURCE_FILES
10
  ${WHISPER_LIB_DIR}/ggml.c
11
  ${WHISPER_LIB_DIR}/ggml-alloc.c
 
 
12
  ${WHISPER_LIB_DIR}/whisper.cpp
13
  ${CMAKE_SOURCE_DIR}/jni.c
14
  )
 
9
  SOURCE_FILES
10
  ${WHISPER_LIB_DIR}/ggml.c
11
  ${WHISPER_LIB_DIR}/ggml-alloc.c
12
+ ${WHISPER_LIB_DIR}/ggml-backend.c
13
+ ${WHISPER_LIB_DIR}/ggml-quants.c
14
  ${WHISPER_LIB_DIR}/whisper.cpp
15
  ${CMAKE_SOURCE_DIR}/jni.c
16
  )
examples/whisper.objc/whisper.objc.xcodeproj/project.pbxproj CHANGED
@@ -20,6 +20,8 @@
20
  18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML"; }; };
21
  18627C9629052C5800BD2A04 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9529052C5800BD2A04 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE"; }; };
22
  18627C9B29052CFF00BD2A04 /* ggml-base.en.bin in Resources */ = {isa = PBXBuildFile; fileRef = 18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */; };
 
 
23
  7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */; };
24
  7FE3424C2A0C3FA20015A058 /* whisper-encoder.mm in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342472A0C3FA20015A058 /* whisper-encoder.mm */; };
25
  7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE3424A2A0C3FA20015A058 /* whisper-decoder-impl.m */; };
@@ -61,6 +63,12 @@
61
  18627C9529052C5800BD2A04 /* ggml.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = ggml.c; path = ../../../ggml.c; sourceTree = "<group>"; };
62
  18627C9729052C6600BD2A04 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = ggml.h; path = ../../../ggml.h; sourceTree = "<group>"; };
63
  18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */ = {isa = PBXFileReference; lastKnownFileType = archive.macbinary; name = "ggml-base.en.bin"; path = "../../../models/ggml-base.en.bin"; sourceTree = "<group>"; };
 
 
 
 
 
 
64
  7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = "whisper-encoder-impl.m"; sourceTree = "<group>"; };
65
  7FE342462A0C3FA20015A058 /* whisper-encoder.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "whisper-encoder.h"; sourceTree = "<group>"; };
66
  7FE342472A0C3FA20015A058 /* whisper-encoder.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = "whisper-encoder.mm"; sourceTree = "<group>"; };
@@ -100,6 +108,12 @@
100
  18627C7829052BDF00BD2A04 /* whisper.objc */ = {
101
  isa = PBXGroup;
102
  children = (
 
 
 
 
 
 
103
  1844471D2AB2195F007D6BFE /* ggml-metal.metal */,
104
  1844471B2AB21655007D6BFE /* ggml-metal.m */,
105
  184447182AB211A2007D6BFE /* ggml-alloc.c */,
@@ -214,12 +228,14 @@
214
  buildActionMask = 2147483647;
215
  files = (
216
  18627C8129052BDF00BD2A04 /* ViewController.m in Sources */,
 
217
  7FE3424C2A0C3FA20015A058 /* whisper-encoder.mm in Sources */,
218
  18627C9429052C4900BD2A04 /* whisper.cpp in Sources */,
219
  18627C9629052C5800BD2A04 /* ggml.c in Sources */,
220
  18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */,
221
  7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */,
222
  1844471A2AB211A2007D6BFE /* ggml-alloc.c in Sources */,
 
223
  18627C8C29052BE000BD2A04 /* main.m in Sources */,
224
  18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */,
225
  1844471C2AB21655007D6BFE /* ggml-metal.m in Sources */,
 
20
  18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML"; }; };
21
  18627C9629052C5800BD2A04 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9529052C5800BD2A04 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE"; }; };
22
  18627C9B29052CFF00BD2A04 /* ggml-base.en.bin in Resources */ = {isa = PBXBuildFile; fileRef = 18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */; };
23
+ 18ABE15A2AF556340044A204 /* ggml-backend.c in Sources */ = {isa = PBXBuildFile; fileRef = 18ABE1572AF556340044A204 /* ggml-backend.c */; };
24
+ 18ABE15B2AF556340044A204 /* ggml-quants.c in Sources */ = {isa = PBXBuildFile; fileRef = 18ABE1592AF556340044A204 /* ggml-quants.c */; };
25
  7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */; };
26
  7FE3424C2A0C3FA20015A058 /* whisper-encoder.mm in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342472A0C3FA20015A058 /* whisper-encoder.mm */; };
27
  7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE3424A2A0C3FA20015A058 /* whisper-decoder-impl.m */; };
 
63
  18627C9529052C5800BD2A04 /* ggml.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = ggml.c; path = ../../../ggml.c; sourceTree = "<group>"; };
64
  18627C9729052C6600BD2A04 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = ggml.h; path = ../../../ggml.h; sourceTree = "<group>"; };
65
  18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */ = {isa = PBXFileReference; lastKnownFileType = archive.macbinary; name = "ggml-base.en.bin"; path = "../../../models/ggml-base.en.bin"; sourceTree = "<group>"; };
66
+ 18ABE1542AF556340044A204 /* ggml-quants.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = "ggml-quants.h"; path = "../../../ggml-quants.h"; sourceTree = "<group>"; };
67
+ 18ABE1552AF556340044A204 /* ggml-backend.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = "ggml-backend.h"; path = "../../../ggml-backend.h"; sourceTree = "<group>"; };
68
+ 18ABE1562AF556340044A204 /* ggml-backend-impl.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = "ggml-backend-impl.h"; path = "../../../ggml-backend-impl.h"; sourceTree = "<group>"; };
69
+ 18ABE1572AF556340044A204 /* ggml-backend.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = "ggml-backend.c"; path = "../../../ggml-backend.c"; sourceTree = "<group>"; };
70
+ 18ABE1582AF556340044A204 /* ggml-impl.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = "ggml-impl.h"; path = "../../../ggml-impl.h"; sourceTree = "<group>"; };
71
+ 18ABE1592AF556340044A204 /* ggml-quants.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = "ggml-quants.c"; path = "../../../ggml-quants.c"; sourceTree = "<group>"; };
72
  7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = "whisper-encoder-impl.m"; sourceTree = "<group>"; };
73
  7FE342462A0C3FA20015A058 /* whisper-encoder.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "whisper-encoder.h"; sourceTree = "<group>"; };
74
  7FE342472A0C3FA20015A058 /* whisper-encoder.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = "whisper-encoder.mm"; sourceTree = "<group>"; };
 
108
  18627C7829052BDF00BD2A04 /* whisper.objc */ = {
109
  isa = PBXGroup;
110
  children = (
111
+ 18ABE1562AF556340044A204 /* ggml-backend-impl.h */,
112
+ 18ABE1572AF556340044A204 /* ggml-backend.c */,
113
+ 18ABE1552AF556340044A204 /* ggml-backend.h */,
114
+ 18ABE1582AF556340044A204 /* ggml-impl.h */,
115
+ 18ABE1592AF556340044A204 /* ggml-quants.c */,
116
+ 18ABE1542AF556340044A204 /* ggml-quants.h */,
117
  1844471D2AB2195F007D6BFE /* ggml-metal.metal */,
118
  1844471B2AB21655007D6BFE /* ggml-metal.m */,
119
  184447182AB211A2007D6BFE /* ggml-alloc.c */,
 
228
  buildActionMask = 2147483647;
229
  files = (
230
  18627C8129052BDF00BD2A04 /* ViewController.m in Sources */,
231
+ 18ABE15B2AF556340044A204 /* ggml-quants.c in Sources */,
232
  7FE3424C2A0C3FA20015A058 /* whisper-encoder.mm in Sources */,
233
  18627C9429052C4900BD2A04 /* whisper.cpp in Sources */,
234
  18627C9629052C5800BD2A04 /* ggml.c in Sources */,
235
  18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */,
236
  7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */,
237
  1844471A2AB211A2007D6BFE /* ggml-alloc.c in Sources */,
238
+ 18ABE15A2AF556340044A204 /* ggml-backend.c in Sources */,
239
  18627C8C29052BE000BD2A04 /* main.m in Sources */,
240
  18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */,
241
  1844471C2AB21655007D6BFE /* ggml-metal.m in Sources */,
examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj CHANGED
@@ -20,6 +20,8 @@
20
  0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC929539EB0003032C3 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE -Wno-shorten-64-to-32"; }; };
21
  0AAC5DCE2953A05C003032C3 /* WhisperState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DCD2953A05C003032C3 /* WhisperState.swift */; };
22
  0AAC5DD12953A394003032C3 /* LibWhisper.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DD02953A394003032C3 /* LibWhisper.swift */; };
 
 
23
  18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 18AED47F2AB21F2B009D854F /* ggml-alloc.c */; };
24
  /* End PBXBuildFile section */
25
 
@@ -42,6 +44,12 @@
42
  0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ggml.h; sourceTree = "<group>"; };
43
  0AAC5DCD2953A05C003032C3 /* WhisperState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperState.swift; sourceTree = "<group>"; };
44
  0AAC5DD02953A394003032C3 /* LibWhisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibWhisper.swift; sourceTree = "<group>"; };
 
 
 
 
 
 
45
  18AED47F2AB21F2B009D854F /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = "ggml-alloc.c"; sourceTree = "<group>"; };
46
  18AED4802AB21F2B009D854F /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-alloc.h"; sourceTree = "<group>"; };
47
  /* End PBXFileReference section */
@@ -127,6 +135,12 @@
127
  0AAC5DC529539E89003032C3 /* whisper.cpp */ = {
128
  isa = PBXGroup;
129
  children = (
 
 
 
 
 
 
130
  18AED47F2AB21F2B009D854F /* ggml-alloc.c */,
131
  18AED4802AB21F2B009D854F /* ggml-alloc.h */,
132
  0AAC5DC929539EB0003032C3 /* ggml.c */,
@@ -242,12 +256,14 @@
242
  0AAC5D9D29539CCF003032C3 /* ContentView.swift in Sources */,
243
  0AAC5D9B29539CCF003032C3 /* WhisperCppDemoApp.swift in Sources */,
244
  0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */,
 
245
  0AAC5DCE2953A05C003032C3 /* WhisperState.swift in Sources */,
246
  0AAC5DD12953A394003032C3 /* LibWhisper.swift in Sources */,
247
  0AA7514C2953B569001EE061 /* RiffWaveUtils.swift in Sources */,
248
  0AAC5DCB29539EB1003032C3 /* whisper.cpp in Sources */,
249
  0AA7514E2953D958001EE061 /* Recorder.swift in Sources */,
250
  18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */,
 
251
  );
252
  runOnlyForDeploymentPostprocessing = 0;
253
  };
 
20
  0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC929539EB0003032C3 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE -Wno-shorten-64-to-32"; }; };
21
  0AAC5DCE2953A05C003032C3 /* WhisperState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DCD2953A05C003032C3 /* WhisperState.swift */; };
22
  0AAC5DD12953A394003032C3 /* LibWhisper.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DD02953A394003032C3 /* LibWhisper.swift */; };
23
+ 18ABE1522AF555FA0044A204 /* ggml-backend.c in Sources */ = {isa = PBXBuildFile; fileRef = 18ABE14C2AF555FA0044A204 /* ggml-backend.c */; };
24
+ 18ABE1532AF555FA0044A204 /* ggml-quants.c in Sources */ = {isa = PBXBuildFile; fileRef = 18ABE1512AF555FA0044A204 /* ggml-quants.c */; };
25
  18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 18AED47F2AB21F2B009D854F /* ggml-alloc.c */; };
26
  /* End PBXBuildFile section */
27
 
 
44
  0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ggml.h; sourceTree = "<group>"; };
45
  0AAC5DCD2953A05C003032C3 /* WhisperState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperState.swift; sourceTree = "<group>"; };
46
  0AAC5DD02953A394003032C3 /* LibWhisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibWhisper.swift; sourceTree = "<group>"; };
47
+ 18ABE14C2AF555FA0044A204 /* ggml-backend.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = "ggml-backend.c"; sourceTree = "<group>"; };
48
+ 18ABE14D2AF555FA0044A204 /* ggml-backend.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-backend.h"; sourceTree = "<group>"; };
49
+ 18ABE14E2AF555FA0044A204 /* ggml-backend-impl.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-backend-impl.h"; sourceTree = "<group>"; };
50
+ 18ABE14F2AF555FA0044A204 /* ggml-quants.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-quants.h"; sourceTree = "<group>"; };
51
+ 18ABE1502AF555FA0044A204 /* ggml-impl.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-impl.h"; sourceTree = "<group>"; };
52
+ 18ABE1512AF555FA0044A204 /* ggml-quants.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = "ggml-quants.c"; sourceTree = "<group>"; };
53
  18AED47F2AB21F2B009D854F /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = "ggml-alloc.c"; sourceTree = "<group>"; };
54
  18AED4802AB21F2B009D854F /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-alloc.h"; sourceTree = "<group>"; };
55
  /* End PBXFileReference section */
 
135
  0AAC5DC529539E89003032C3 /* whisper.cpp */ = {
136
  isa = PBXGroup;
137
  children = (
138
+ 18ABE14E2AF555FA0044A204 /* ggml-backend-impl.h */,
139
+ 18ABE14C2AF555FA0044A204 /* ggml-backend.c */,
140
+ 18ABE14D2AF555FA0044A204 /* ggml-backend.h */,
141
+ 18ABE1502AF555FA0044A204 /* ggml-impl.h */,
142
+ 18ABE1512AF555FA0044A204 /* ggml-quants.c */,
143
+ 18ABE14F2AF555FA0044A204 /* ggml-quants.h */,
144
  18AED47F2AB21F2B009D854F /* ggml-alloc.c */,
145
  18AED4802AB21F2B009D854F /* ggml-alloc.h */,
146
  0AAC5DC929539EB0003032C3 /* ggml.c */,
 
256
  0AAC5D9D29539CCF003032C3 /* ContentView.swift in Sources */,
257
  0AAC5D9B29539CCF003032C3 /* WhisperCppDemoApp.swift in Sources */,
258
  0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */,
259
+ 18ABE1532AF555FA0044A204 /* ggml-quants.c in Sources */,
260
  0AAC5DCE2953A05C003032C3 /* WhisperState.swift in Sources */,
261
  0AAC5DD12953A394003032C3 /* LibWhisper.swift in Sources */,
262
  0AA7514C2953B569001EE061 /* RiffWaveUtils.swift in Sources */,
263
  0AAC5DCB29539EB1003032C3 /* whisper.cpp in Sources */,
264
  0AA7514E2953D958001EE061 /* Recorder.swift in Sources */,
265
  18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */,
266
+ 18ABE1522AF555FA0044A204 /* ggml-backend.c in Sources */,
267
  );
268
  runOnlyForDeploymentPostprocessing = 0;
269
  };
extra/sync-ggml.sh CHANGED
@@ -1,20 +1,30 @@
1
  #!/bin/bash
2
 
3
- cp -rpv ../ggml/src/ggml.c ./ggml.c
4
- cp -rpv ../ggml/src/ggml-alloc.c ./ggml-alloc.c
5
- cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h
6
- cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu
7
- cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
8
- cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp
9
- cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h
10
- cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m
11
- cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
12
- cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h
13
- cp -rpv ../ggml/include/ggml/ggml-alloc.h ./ggml-alloc.h
14
- cp -rpv ../ggml/examples/common.h ./examples/common.h
15
- cp -rpv ../ggml/examples/common.cpp ./examples/common.cpp
16
- cp -rpv ../ggml/examples/common-ggml.h ./examples/common-ggml.h
17
- cp -rpv ../ggml/examples/common-ggml.cpp ./examples/common-ggml.cpp
 
 
 
 
 
 
 
 
 
 
18
 
19
  cp -rpv ../ggml/examples/whisper/whisper.h ./whisper.h
20
  cp -rpv ../ggml/examples/whisper/whisper.cpp ./whisper.cpp
 
1
  #!/bin/bash
2
 
3
+ cp -rpv ../ggml/src/ggml.c ./ggml.c
4
+ cp -rpv ../ggml/src/ggml-alloc.c ./ggml-alloc.c
5
+ cp -rpv ../ggml/src/ggml-backend-impl.h ./ggml-backend-impl.h
6
+ cp -rpv ../ggml/src/ggml-backend.c ./ggml-backend.c
7
+ cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu
8
+ cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h
9
+ cp -rpv ../ggml/src/ggml-impl.h ./ggml-impl.h
10
+ cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h
11
+ cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m
12
+ cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
13
+ #cp -rpv ../ggml/src/ggml-mpi.h ./ggml-mpi.h
14
+ #cp -rpv ../ggml/src/ggml-mpi.c ./ggml-mpi.c
15
+ cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp
16
+ cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
17
+ cp -rpv ../ggml/src/ggml-quants.c ./ggml-quants.c
18
+ cp -rpv ../ggml/src/ggml-quants.h ./ggml-quants.h
19
+
20
+ cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h
21
+ cp -rpv ../ggml/include/ggml/ggml-alloc.h ./ggml-alloc.h
22
+ cp -rpv ../ggml/include/ggml/ggml-backend.h ./ggml-backend.h
23
+
24
+ cp -rpv ../ggml/examples/common.h ./examples/common.h
25
+ cp -rpv ../ggml/examples/common.cpp ./examples/common.cpp
26
+ cp -rpv ../ggml/examples/common-ggml.h ./examples/common-ggml.h
27
+ cp -rpv ../ggml/examples/common-ggml.cpp ./examples/common-ggml.cpp
28
 
29
  cp -rpv ../ggml/examples/whisper/whisper.h ./whisper.h
30
  cp -rpv ../ggml/examples/whisper/whisper.cpp ./whisper.cpp
ggml-alloc.c CHANGED
@@ -1,69 +1,21 @@
1
  #include "ggml-alloc.h"
 
2
  #include "ggml.h"
 
3
  #include <assert.h>
 
4
  #include <stdarg.h>
5
  #include <stdio.h>
6
  #include <stdlib.h>
7
  #include <string.h>
8
 
9
- #ifdef __has_include
10
- #if __has_include(<unistd.h>)
11
- #include <unistd.h>
12
- #if defined(_POSIX_MAPPED_FILES)
13
- #include <sys/types.h>
14
- #include <sys/mman.h>
15
- #endif
16
- #endif
17
- #endif
18
-
19
- #if defined(_WIN32)
20
- #define WIN32_LEAN_AND_MEAN
21
- #ifndef NOMINMAX
22
- #define NOMINMAX
23
- #endif
24
- #include <windows.h>
25
- #include <memoryapi.h>
26
- #endif
27
-
28
-
29
- #define UNUSED(x) (void)(x)
30
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
31
- #define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
32
 
33
  //#define GGML_ALLOCATOR_DEBUG
34
 
35
- //#define AT_PRINTF printf
36
- #define AT_PRINTF(...) ((void)0)
37
-
38
- struct hash_node {
39
- struct ggml_tensor * t;
40
- int n_children;
41
- int n_views;
42
- };
43
-
44
- static size_t hash(void * p) {
45
- return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
46
- }
47
-
48
- static struct hash_node * hash_get(struct hash_node hash_table[], struct ggml_tensor * t) {
49
- size_t h = hash(t);
50
-
51
- // linear probing
52
- size_t i = h;
53
- while (hash_table[i].t != NULL) {
54
- if (hash_table[i].t == t) {
55
- return &hash_table[i];
56
- }
57
- i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
58
- if (i == h) {
59
- // hash table is full
60
- GGML_ASSERT(false);
61
- }
62
- }
63
-
64
- hash_table[i].t = t;
65
- return &hash_table[i];
66
- }
67
 
68
  // TODO: GGML_PAD ?
69
  static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
@@ -77,19 +29,18 @@ struct free_block {
77
  size_t size;
78
  };
79
 
80
- #define MAX_FREE_BLOCKS 128
81
-
82
- struct ggml_allocr {
83
- void * data;
84
- size_t size;
85
  size_t alignment;
 
86
  int n_free_blocks;
87
  struct free_block free_blocks[MAX_FREE_BLOCKS];
88
- struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE];
89
  size_t max_size;
 
90
  bool measure;
91
- int parse_seq[GGML_MAX_CONCUR];
92
- int parse_seq_len;
93
 
94
  #ifdef GGML_ALLOCATOR_DEBUG
95
  struct ggml_tensor * allocated_tensors[1024];
@@ -97,7 +48,7 @@ struct ggml_allocr {
97
  };
98
 
99
  #ifdef GGML_ALLOCATOR_DEBUG
100
- static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
101
  for (int i = 0; i < 1024; i++) {
102
  if (alloc->allocated_tensors[i] == NULL) {
103
  alloc->allocated_tensors[i] = tensor;
@@ -106,7 +57,7 @@ static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor
106
  }
107
  GGML_ASSERT(!"out of allocated_tensors");
108
  }
109
- static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
110
  for (int i = 0; i < 1024; i++) {
111
  if (alloc->allocated_tensors[i] == tensor ||
112
  (alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {
@@ -119,28 +70,20 @@ static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tens
119
  }
120
  #endif
121
 
122
- static size_t ggml_allocr_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
123
- return ggml_nbytes(tensor);
124
-
125
- UNUSED(alloc);
126
- }
127
-
128
  // check if a tensor is allocated by this buffer
129
- static bool ggml_allocr_is_own(struct ggml_allocr * alloc, const struct ggml_tensor * tensor) {
130
- void * ptr = tensor->data;
131
- return ptr >= alloc->data && (char *)ptr < (char *)alloc->data + alloc->max_size;
132
  }
133
 
134
  static bool ggml_is_view(struct ggml_tensor * t) {
135
  return t->view_src != NULL;
136
  }
137
 
138
- void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
139
- #ifdef GGML_ALLOCATOR_DEBUG
140
  GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources
141
  GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated
142
- #endif
143
- size_t size = ggml_allocr_get_alloc_size(alloc, tensor);
144
  size = aligned_offset(NULL, size, alloc->alignment);
145
 
146
  AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
@@ -187,6 +130,10 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
187
  }
188
 
189
  tensor->data = addr;
 
 
 
 
190
 
191
  #ifdef GGML_ALLOCATOR_DEBUG
192
  add_allocated_tensor(alloc, tensor);
@@ -202,23 +149,28 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
202
  }
203
  #endif
204
 
205
- alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->data + size);
206
  }
207
 
208
  // this is a very naive implementation, but for our case the number of free blocks should be very small
209
- static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
210
- void * ptr = tensor->data;
211
-
212
- if (ggml_allocr_is_own(alloc, tensor) == false) {
213
  // the tensor was not allocated in this buffer
214
  // this can happen because the graph allocator will try to free weights and other tensors from different buffers
215
  // the easiest way to deal with this is just to ignore it
 
216
  return;
217
  }
218
 
219
- size_t size = ggml_allocr_get_alloc_size(alloc, tensor);
 
 
220
  size = aligned_offset(NULL, size, alloc->alignment);
221
- AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
 
 
 
 
222
 
223
  #ifdef GGML_ALLOCATOR_DEBUG
224
  remove_allocated_tensor(alloc, tensor);
@@ -272,136 +224,180 @@ static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tens
272
  alloc->n_free_blocks++;
273
  }
274
 
275
- void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n) {
276
- for (int i = 0; i < n; i++) {
277
- alloc->parse_seq[i] = list[i];
 
 
 
 
 
 
278
  }
279
- alloc->parse_seq_len = n;
280
  }
281
 
282
- void ggml_allocr_reset(struct ggml_allocr * alloc) {
283
- alloc->n_free_blocks = 1;
284
- size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment);
285
- alloc->free_blocks[0].addr = (char *)alloc->data + align_offset;
286
- alloc->free_blocks[0].size = alloc->size - align_offset;
287
- }
288
 
289
- struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) {
290
- struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
291
 
292
- *alloc = (struct ggml_allocr){
293
- /*.data = */ data,
294
- /*.size = */ size,
 
295
  /*.alignment = */ alignment,
296
  /*.n_free_blocks = */ 0,
297
  /*.free_blocks = */ {{0}},
298
- /*.hash_table = */ {{0}},
299
  /*.max_size = */ 0,
300
  /*.measure = */ false,
301
- /*.parse_seq = */ {0},
302
- /*.parse_seq_len = */ 0,
303
  #ifdef GGML_ALLOCATOR_DEBUG
304
  /*.allocated_tensors = */ {0},
305
  #endif
306
  };
307
 
308
- ggml_allocr_reset(alloc);
309
 
310
  return alloc;
311
  }
312
 
313
- // OS specific functions to allocate and free uncommitted virtual memory
314
- static void * alloc_vmem(size_t size) {
315
- #if defined(_WIN32)
316
- return VirtualAlloc(NULL, size, MEM_RESERVE, PAGE_NOACCESS);
317
- #elif defined(_POSIX_MAPPED_FILES)
318
- void * ptr = mmap(NULL, size, PROT_NONE, MAP_PRIVATE | MAP_ANON, -1, 0);
319
- if (ptr == MAP_FAILED) {
320
- return NULL;
321
- }
322
- return ptr;
323
- #else
324
- // use a fixed address for other platforms
325
- uintptr_t base_addr = (uintptr_t)-size - 0x100;
326
- return (void *)base_addr;
327
- #endif
328
- }
329
 
330
- static void free_vmem(void * base_addr, size_t size) {
331
- #if defined(_WIN32)
332
- VirtualFree(base_addr, 0, MEM_RELEASE);
333
- UNUSED(size);
334
- #elif defined(_POSIX_MAPPED_FILES)
335
- munmap(base_addr, size);
336
- #else
337
- // nothing to do
338
- UNUSED(base_addr);
339
- UNUSED(size);
340
- #endif
341
  }
342
 
343
- // allocate uncommitted virtual memory to measure the size of the graph
344
- static void alloc_measure_vmem(void ** base_addr, size_t * size) {
345
- // 128GB for 64-bit, 1GB for 32-bit
346
- *size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<37;
347
- do {
348
- *base_addr = alloc_vmem(*size);
349
- if (*base_addr != NULL) {
350
- AT_PRINTF("allocated %.2f GB of virtual memory for measure buffer at %p\n", *size / 1024.0 / 1024.0 / 1024.0, *base_addr);
351
- return;
352
- }
353
- // try again with half the size
354
- *size /= 2;
355
- } while (*size > 0);
356
 
357
- GGML_ASSERT(!"failed to allocate virtual memory for measure buffer");
 
 
 
 
 
358
  }
359
 
360
- static void free_measure_vmem(void * base_addr, size_t size) {
361
- free_vmem(base_addr, size);
 
 
 
362
  }
363
 
364
- struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
365
- struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
366
 
367
- void * base_addr;
368
- size_t size;
369
-
370
- alloc_measure_vmem(&base_addr, &size);
371
-
372
- *alloc = (struct ggml_allocr){
373
- /*.data = */ base_addr,
374
- /*.size = */ size,
375
- /*.alignment = */ alignment,
376
  /*.n_free_blocks = */ 0,
377
  /*.free_blocks = */ {{0}},
378
- /*.hash_table = */ {{0}},
379
  /*.max_size = */ 0,
380
- /*.measure = */ true,
381
- /*.parse_seq = */ {0},
382
- /*.parse_seq_len = */ 0,
383
  #ifdef GGML_ALLOCATOR_DEBUG
384
  /*.allocated_tensors = */ {0},
385
  #endif
386
  };
387
 
388
- ggml_allocr_reset(alloc);
389
 
390
  return alloc;
391
  }
392
 
393
- void ggml_allocr_free(struct ggml_allocr * alloc) {
394
- if (alloc->measure) {
395
- free_measure_vmem(alloc->data, alloc->size);
 
 
 
 
 
 
 
 
396
  }
397
  free(alloc);
398
  }
399
 
400
- bool ggml_allocr_is_measure(struct ggml_allocr * alloc) {
401
  return alloc->measure;
402
  }
403
 
404
- //////////// compute graph allocator
405
 
406
  static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
407
  if (a->type != b->type) {
@@ -435,7 +431,6 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
435
  case GGML_OP_ROPE:
436
  case GGML_OP_RMS_NORM:
437
  case GGML_OP_SOFT_MAX:
438
- case GGML_OP_CONT:
439
  return true;
440
 
441
  default:
@@ -443,12 +438,38 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
443
  }
444
  }
445
 
446
- static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node) {
447
- struct hash_node * ht = alloc->hash_table;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
  if (node->data == NULL) {
449
  if (ggml_is_view(node)) {
450
- assert(node->view_src->data != NULL);
451
- node->data = (char *)node->view_src->data + node->view_offs;
452
  } else {
453
  // see if we can reuse a parent's buffer (inplace)
454
  if (ggml_op_can_inplace(node->op)) {
@@ -459,16 +480,16 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
459
  }
460
 
461
  // if the node's data is external, then we cannot re-use it
462
- if (ggml_allocr_is_own(alloc, parent) == false) {
463
  AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
464
  continue;
465
  }
466
 
467
- struct hash_node * p_hn = hash_get(ht, parent);
468
  if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
469
  if (ggml_is_view(parent)) {
470
  struct ggml_tensor * view_src = parent->view_src;
471
- struct hash_node * view_src_hn = hash_get(ht, view_src);
472
  if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
473
  // TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
474
  // the parent's data that it will need later (same layout requirement). the problem is that then
@@ -476,158 +497,270 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
476
  // adding a view_src pointer to the tensor would solve this and simplify the code dealing with views
477
  // for now, we only reuse the parent's data if the offset is zero (view_src->data == parent->data)
478
  AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
479
- node->data = parent->data;
 
 
480
  return;
481
  }
482
  }
483
  else {
484
  AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
485
- node->data = parent->data;
 
 
486
  return;
487
  }
488
  }
489
  }
490
  }
491
- ggml_allocr_alloc(alloc, node);
492
  }
493
  }
494
  }
495
 
496
- static size_t ggml_allocr_alloc_graph_tensors_n(
497
- struct ggml_allocr * alloc,
498
- struct ggml_cgraph ** graphs, int n_graphs,
499
- struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) {
500
 
501
- // reset hash table
502
- struct hash_node * ht = alloc->hash_table;
503
- memset(ht, 0, sizeof(struct hash_node) * GGML_GRAPH_HASHTABLE_SIZE);
 
 
 
504
 
505
  // count number of children and views
506
- for (int g = 0; g < n_graphs; g++) {
507
- struct ggml_cgraph * gf = graphs[g];
508
- for (int i = 0; i < gf->n_nodes; i++) {
509
- struct ggml_tensor * node = gf->nodes[i];
 
 
 
 
 
 
 
510
 
511
- if (ggml_is_view(node)) {
512
- struct ggml_tensor * view_src = node->view_src;
513
- hash_get(ht, view_src)->n_views += 1;
 
514
  }
 
 
 
 
 
 
515
 
 
 
 
 
 
 
 
 
 
 
 
 
516
  for (int j = 0; j < GGML_MAX_SRC; j++) {
517
  struct ggml_tensor * parent = node->src[j];
518
  if (parent == NULL) {
519
  break;
520
  }
521
- hash_get(ht, parent)->n_children += 1;
522
  }
523
- }
524
- }
525
 
526
- // allocate tensors
527
- for (int g = 0; g < n_graphs; g++) {
528
- struct ggml_cgraph * gf = graphs[g];
529
- AT_PRINTF("####### graph %d/%d\n", g, n_graphs);
530
- // graph inputs are allocated first to ensure that they are not overwritten by each other
531
- if (inputs != NULL && inputs[g] != NULL) {
532
- for (int i = 0; inputs[g][i] != NULL; i++) {
533
- struct ggml_tensor * input = inputs[g][i];
534
- AT_PRINTF("input: %s\n", input->name);
535
- allocate_node(alloc, input);
 
 
 
536
  }
 
537
  }
538
- // if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers
539
- int last_barrier_pos = 0;
540
- int n_nodes = alloc->parse_seq_len ? alloc->parse_seq_len : gf->n_nodes;
541
 
542
- for (int ind = 0; ind < n_nodes; ind++) {
543
- // allocate a node if there is no parse_seq or this is not a barrier
544
- if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] != -1) {
545
- int i = alloc->parse_seq_len ? alloc->parse_seq[ind] : ind;
546
- struct ggml_tensor * node = gf->nodes[i];
 
 
 
 
547
 
548
- // allocate parents (leafs)
549
  for (int j = 0; j < GGML_MAX_SRC; j++) {
550
  struct ggml_tensor * parent = node->src[j];
551
  if (parent == NULL) {
552
  break;
553
  }
554
- allocate_node(alloc, parent);
555
- }
556
 
557
- // allocate node
558
- allocate_node(alloc, node);
559
 
560
- AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name);
561
- for (int j = 0; j < GGML_MAX_SRC; j++) {
562
- struct ggml_tensor * parent = node->src[j];
563
- if (parent == NULL) {
564
- break;
565
- }
566
- AT_PRINTF("%s", parent->name);
567
- if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
568
- AT_PRINTF(", ");
569
- }
570
- }
571
- AT_PRINTF("\n");
572
- }
573
-
574
- // update parents
575
- // update immediately if there is no parse_seq
576
- // update only at barriers if there is parse_seq
577
- if ((alloc->parse_seq_len == 0) || alloc->parse_seq[ind] == -1) {
578
- int update_start = alloc->parse_seq_len ? last_barrier_pos : ind;
579
- int update_end = alloc->parse_seq_len ? ind : ind + 1;
580
- for (int i = update_start; i < update_end; i++) {
581
- int node_i = alloc->parse_seq_len ? alloc->parse_seq[i] : i;
582
- struct ggml_tensor * node = gf->nodes[node_i];
583
-
584
- for (int j = 0; j < GGML_MAX_SRC; j++) {
585
- struct ggml_tensor * parent = node->src[j];
586
- if (parent == NULL) {
587
- break;
588
- }
589
- struct hash_node * p_hn = hash_get(ht, parent);
590
- p_hn->n_children -= 1;
591
-
592
- //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
593
-
594
- if (p_hn->n_children == 0 && p_hn->n_views == 0) {
595
- if (ggml_is_view(parent)) {
596
- struct ggml_tensor * view_src = parent->view_src;
597
- struct hash_node * view_src_hn = hash_get(ht, view_src);
598
- view_src_hn->n_views -= 1;
599
- AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
600
- if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
601
- ggml_allocr_free_tensor(alloc, view_src);
602
- }
603
- }
604
- else {
605
- if (parent->data != node->data) {
606
- ggml_allocr_free_tensor(alloc, parent);
607
- }
608
  }
609
  }
 
 
 
610
  }
611
  }
612
- AT_PRINTF("\n");
613
- if (alloc->parse_seq_len) {
614
- last_barrier_pos = ind + 1;
615
- }
616
  }
617
- }
618
- // free graph outputs here that wouldn't be freed otherwise because they have no children
619
- if (outputs != NULL && outputs[g] != NULL) {
620
- for (int i = 0; outputs[g][i] != NULL; i++) {
621
- struct ggml_tensor * output = outputs[g][i];
622
- AT_PRINTF("output: %s\n", output->name);
623
- ggml_allocr_free_tensor(alloc, output);
624
  }
625
  }
626
  }
 
627
 
628
- return alloc->max_size;
629
  }
630
 
631
- size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
632
- return ggml_allocr_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
633
  }
 
1
  #include "ggml-alloc.h"
2
+ #include "ggml-backend-impl.h"
3
  #include "ggml.h"
4
+ #include "ggml-impl.h"
5
  #include <assert.h>
6
+ #include <limits.h>
7
  #include <stdarg.h>
8
  #include <stdio.h>
9
  #include <stdlib.h>
10
  #include <string.h>
11
 
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
13
+ #define MAX_FREE_BLOCKS 256
14
 
15
  //#define GGML_ALLOCATOR_DEBUG
16
 
17
+ //#define AT_PRINTF(...) fprintf(stderr, __VA_ARGS__)
18
+ #define AT_PRINTF(...)
19
 
20
  // TODO: GGML_PAD ?
21
  static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
 
29
  size_t size;
30
  };
31
 
32
+ struct ggml_tallocr {
33
+ struct ggml_backend_buffer * buffer;
34
+ bool buffer_owned;
35
+ void * base;
 
36
  size_t alignment;
37
+
38
  int n_free_blocks;
39
  struct free_block free_blocks[MAX_FREE_BLOCKS];
40
+
41
  size_t max_size;
42
+
43
  bool measure;
 
 
44
 
45
  #ifdef GGML_ALLOCATOR_DEBUG
46
  struct ggml_tensor * allocated_tensors[1024];
 
48
  };
49
 
50
  #ifdef GGML_ALLOCATOR_DEBUG
51
+ static void add_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
52
  for (int i = 0; i < 1024; i++) {
53
  if (alloc->allocated_tensors[i] == NULL) {
54
  alloc->allocated_tensors[i] = tensor;
 
57
  }
58
  GGML_ASSERT(!"out of allocated_tensors");
59
  }
60
+ static void remove_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
61
  for (int i = 0; i < 1024; i++) {
62
  if (alloc->allocated_tensors[i] == tensor ||
63
  (alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {
 
70
  }
71
  #endif
72
 
 
 
 
 
 
 
73
  // check if a tensor is allocated by this buffer
74
+ static bool ggml_tallocr_is_own(ggml_tallocr_t alloc, const struct ggml_tensor * tensor) {
75
+ return tensor->buffer == alloc->buffer;
 
76
  }
77
 
78
  static bool ggml_is_view(struct ggml_tensor * t) {
79
  return t->view_src != NULL;
80
  }
81
 
82
+ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
 
83
  GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources
84
  GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated
85
+
86
+ size_t size = ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor);
87
  size = aligned_offset(NULL, size, alloc->alignment);
88
 
89
  AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
 
130
  }
131
 
132
  tensor->data = addr;
133
+ tensor->buffer = alloc->buffer;
134
+ if (!alloc->measure) {
135
+ ggml_backend_buffer_init_tensor(alloc->buffer, tensor);
136
+ }
137
 
138
  #ifdef GGML_ALLOCATOR_DEBUG
139
  add_allocated_tensor(alloc, tensor);
 
149
  }
150
  #endif
151
 
152
+ alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->base + size);
153
  }
154
 
155
  // this is a very naive implementation, but for our case the number of free blocks should be very small
156
+ static void ggml_tallocr_free_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
157
+ if (ggml_tallocr_is_own(alloc, tensor) == false) {
 
 
158
  // the tensor was not allocated in this buffer
159
  // this can happen because the graph allocator will try to free weights and other tensors from different buffers
160
  // the easiest way to deal with this is just to ignore it
161
+ // AT_PRINTF("ignoring %s (their buffer: %p, our buffer: %p)\n", tensor->name, (void *)tensor->buffer, (void *)alloc->buffer);
162
  return;
163
  }
164
 
165
+ void * ptr = tensor->data;
166
+
167
+ size_t size = ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor);
168
  size = aligned_offset(NULL, size, alloc->alignment);
169
+ AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
170
+
171
+ if (!alloc->measure) {
172
+ ggml_backend_buffer_free_tensor(alloc->buffer, tensor);
173
+ }
174
 
175
  #ifdef GGML_ALLOCATOR_DEBUG
176
  remove_allocated_tensor(alloc, tensor);
 
224
  alloc->n_free_blocks++;
225
  }
226
 
227
+ void ggml_tallocr_reset(ggml_tallocr_t alloc) {
228
+ alloc->n_free_blocks = 1;
229
+ size_t align_offset = aligned_offset(alloc->base, 0, alloc->alignment);
230
+ alloc->free_blocks[0].addr = (char *)alloc->base + align_offset;
231
+
232
+ if (alloc->measure) {
233
+ alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
234
+ } else {
235
+ alloc->free_blocks[0].size = ggml_backend_buffer_get_size(alloc->buffer) - align_offset;
236
  }
 
237
  }
238
 
239
+ ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) {
240
+ struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(NULL, data, size);
 
 
 
 
241
 
242
+ ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
 
243
 
244
+ *alloc = (struct ggml_tallocr) {
245
+ /*.buffer = */ buffer,
246
+ /*.buffer_owned = */ true,
247
+ /*.base = */ ggml_backend_buffer_get_base(buffer),
248
  /*.alignment = */ alignment,
249
  /*.n_free_blocks = */ 0,
250
  /*.free_blocks = */ {{0}},
 
251
  /*.max_size = */ 0,
252
  /*.measure = */ false,
 
 
253
  #ifdef GGML_ALLOCATOR_DEBUG
254
  /*.allocated_tensors = */ {0},
255
  #endif
256
  };
257
 
258
+ ggml_tallocr_reset(alloc);
259
 
260
  return alloc;
261
  }
262
 
263
+ ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment) {
264
+ ggml_tallocr_t alloc = ggml_tallocr_new((void *)0x1000, SIZE_MAX/2, alignment);
265
+ alloc->measure = true;
 
 
 
 
 
 
 
 
 
 
 
 
 
266
 
267
+ return alloc;
268
  }
269
 
270
+ ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend) {
271
+ // create a backend buffer to get the correct tensor allocation sizes
272
+ ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, 1);
273
 
274
+ // TODO: move alloc initialization to a common ggml_tallocr_new_impl function
275
+ ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer);
276
+ alloc->buffer_owned = true;
277
+ alloc->measure = true;
278
+ ggml_tallocr_reset(alloc);
279
+ return alloc;
280
  }
281
 
282
+ ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size) {
283
+ ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, size);
284
+ ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer);
285
+ alloc->buffer_owned = true;
286
+ return alloc;
287
  }
288
 
289
+ ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer) {
290
+ ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
291
 
292
+ *alloc = (struct ggml_tallocr) {
293
+ /*.buffer = */ buffer,
294
+ /*.buffer_owned = */ false,
295
+ /*.base = */ ggml_backend_buffer_get_base(buffer),
296
+ /*.alignment = */ ggml_backend_buffer_get_alignment(buffer),
 
 
 
 
297
  /*.n_free_blocks = */ 0,
298
  /*.free_blocks = */ {{0}},
 
299
  /*.max_size = */ 0,
300
+ /*.measure = */ false,
 
 
301
  #ifdef GGML_ALLOCATOR_DEBUG
302
  /*.allocated_tensors = */ {0},
303
  #endif
304
  };
305
 
306
+ ggml_tallocr_reset(alloc);
307
 
308
  return alloc;
309
  }
310
 
311
+ struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t alloc) {
312
+ return alloc->buffer;
313
+ }
314
+
315
+ void ggml_tallocr_free(ggml_tallocr_t alloc) {
316
+ if (alloc == NULL) {
317
+ return;
318
+ }
319
+
320
+ if (alloc->buffer_owned) {
321
+ ggml_backend_buffer_free(alloc->buffer);
322
  }
323
  free(alloc);
324
  }
325
 
326
+ bool ggml_tallocr_is_measure(ggml_tallocr_t alloc) {
327
  return alloc->measure;
328
  }
329
 
330
+ size_t ggml_tallocr_max_size(ggml_tallocr_t alloc) {
331
+ return alloc->max_size;
332
+ }
333
+
334
+ // graph allocator
335
+
336
+ struct hash_node {
337
+ int n_children;
338
+ int n_views;
339
+ };
340
+
341
+ struct ggml_gallocr {
342
+ ggml_tallocr_t talloc;
343
+ struct ggml_hash_set hash_set;
344
+ struct hash_node * hash_values;
345
+ size_t hash_values_size;
346
+ ggml_tallocr_t * hash_allocs;
347
+ int * parse_seq;
348
+ int parse_seq_len;
349
+ };
350
+
351
+ ggml_gallocr_t ggml_gallocr_new(void) {
352
+ ggml_gallocr_t galloc = (ggml_gallocr_t)malloc(sizeof(struct ggml_gallocr));
353
+
354
+ *galloc = (struct ggml_gallocr) {
355
+ /*.talloc = */ NULL,
356
+ /*.hash_set = */ {0},
357
+ /*.hash_values = */ NULL,
358
+ /*.hash_values_size = */ 0,
359
+ /*.hash_allocs = */ NULL,
360
+ /*.parse_seq = */ NULL,
361
+ /*.parse_seq_len = */ 0,
362
+ };
363
+
364
+ return galloc;
365
+ }
366
+
367
+ void ggml_gallocr_free(ggml_gallocr_t galloc) {
368
+ if (galloc == NULL) {
369
+ return;
370
+ }
371
+
372
+ if (galloc->hash_set.keys != NULL) {
373
+ free(galloc->hash_set.keys);
374
+ }
375
+ if (galloc->hash_values != NULL) {
376
+ free(galloc->hash_values);
377
+ }
378
+ if (galloc->hash_allocs != NULL) {
379
+ free(galloc->hash_allocs);
380
+ }
381
+ if (galloc->parse_seq != NULL) {
382
+ free(galloc->parse_seq);
383
+ }
384
+ free(galloc);
385
+ }
386
+
387
+ void ggml_gallocr_set_parse_seq(ggml_gallocr_t galloc, const int * list, int n) {
388
+ free(galloc->parse_seq);
389
+ galloc->parse_seq = malloc(sizeof(int) * n);
390
+
391
+ for (int i = 0; i < n; i++) {
392
+ galloc->parse_seq[i] = list[i];
393
+ }
394
+ galloc->parse_seq_len = n;
395
+ }
396
+
397
+ static struct hash_node * hash_get(ggml_gallocr_t galloc, struct ggml_tensor * t) {
398
+ size_t i = ggml_hash_find_or_insert(galloc->hash_set, t);
399
+ return &galloc->hash_values[i];
400
+ }
401
 
402
  static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
403
  if (a->type != b->type) {
 
431
  case GGML_OP_ROPE:
432
  case GGML_OP_RMS_NORM:
433
  case GGML_OP_SOFT_MAX:
 
434
  return true;
435
 
436
  default:
 
438
  }
439
  }
440
 
441
+ static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * node) {
442
+ if (galloc->talloc != NULL) {
443
+ return galloc->talloc;
444
+ }
445
+
446
+ return galloc->hash_allocs[ggml_hash_find_or_insert(galloc->hash_set, node)];
447
+ }
448
+
449
+ static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view) {
450
+ ggml_tallocr_t alloc = node_tallocr(galloc, view);
451
+
452
+ //printf("init_view: %s from src %s\n", view->name, view->view_src->name);
453
+ GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
454
+ view->backend = view->view_src->backend;
455
+ view->buffer = view->view_src->buffer;
456
+ view->data = (char *)view->view_src->data + view->view_offs;
457
+
458
+ // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
459
+ // due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
460
+ assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend);
461
+
462
+ if (!alloc->measure) {
463
+ ggml_backend_buffer_init_tensor(alloc->buffer, view);
464
+ }
465
+ }
466
+
467
+ static void allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
468
+ ggml_tallocr_t alloc = node_tallocr(galloc, node);
469
+
470
  if (node->data == NULL) {
471
  if (ggml_is_view(node)) {
472
+ init_view(galloc, node);
 
473
  } else {
474
  // see if we can reuse a parent's buffer (inplace)
475
  if (ggml_op_can_inplace(node->op)) {
 
480
  }
481
 
482
  // if the node's data is external, then we cannot re-use it
483
+ if (ggml_tallocr_is_own(alloc, parent) == false) {
484
  AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
485
  continue;
486
  }
487
 
488
+ struct hash_node * p_hn = hash_get(galloc, parent);
489
  if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
490
  if (ggml_is_view(parent)) {
491
  struct ggml_tensor * view_src = parent->view_src;
492
+ struct hash_node * view_src_hn = hash_get(galloc, view_src);
493
  if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
494
  // TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
495
  // the parent's data that it will need later (same layout requirement). the problem is that then
 
497
  // adding a view_src pointer to the tensor would solve this and simplify the code dealing with views
498
  // for now, we only reuse the parent's data if the offset is zero (view_src->data == parent->data)
499
  AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
500
+ node->view_src = view_src;
501
+ view_src_hn->n_views += 1;
502
+ init_view(galloc, node);
503
  return;
504
  }
505
  }
506
  else {
507
  AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
508
+ node->view_src = parent;
509
+ p_hn->n_views += 1;
510
+ init_view(galloc, node);
511
  return;
512
  }
513
  }
514
  }
515
  }
516
+ ggml_tallocr_alloc(alloc, node);
517
  }
518
  }
519
  }
520
 
521
+ static void free_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
522
+ ggml_tallocr_t alloc = node_tallocr(galloc, node);
 
 
523
 
524
+ ggml_tallocr_free_tensor(alloc, node);
525
+ }
526
+
527
+ static void ggml_tallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgraph * gf) {
528
+ const int * parse_seq = galloc->parse_seq;
529
+ int parse_seq_len = galloc->parse_seq_len;
530
 
531
  // count number of children and views
532
+ for (int i = 0; i < gf->n_nodes; i++) {
533
+ struct ggml_tensor * node = gf->nodes[i];
534
+
535
+ if (ggml_is_view(node)) {
536
+ struct ggml_tensor * view_src = node->view_src;
537
+ hash_get(galloc, view_src)->n_views += 1;
538
+ if (node->buffer == NULL && node->data != NULL) {
539
+ // view of a pre-allocated tensor, didn't call init_view() yet
540
+ init_view(galloc, node);
541
+ }
542
+ }
543
 
544
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
545
+ struct ggml_tensor * parent = node->src[j];
546
+ if (parent == NULL) {
547
+ break;
548
  }
549
+ hash_get(galloc, parent)->n_children += 1;
550
+ if (ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) {
551
+ init_view(galloc, parent);
552
+ }
553
+ }
554
+ }
555
 
556
+ // allocate tensors
557
+ // if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers
558
+ int last_barrier_pos = 0;
559
+ int n_nodes = parse_seq_len ? parse_seq_len : gf->n_nodes;
560
+
561
+ for (int ind = 0; ind < n_nodes; ind++) {
562
+ // allocate a node if there is no parse_seq or this is not a barrier
563
+ if (parse_seq_len == 0 || parse_seq[ind] != -1) {
564
+ int i = parse_seq_len ? parse_seq[ind] : ind;
565
+ struct ggml_tensor * node = gf->nodes[i];
566
+
567
+ // allocate parents (leafs)
568
  for (int j = 0; j < GGML_MAX_SRC; j++) {
569
  struct ggml_tensor * parent = node->src[j];
570
  if (parent == NULL) {
571
  break;
572
  }
573
+ allocate_node(galloc, parent);
574
  }
 
 
575
 
576
+ // allocate node
577
+ allocate_node(galloc, node);
578
+
579
+ AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name);
580
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
581
+ struct ggml_tensor * parent = node->src[j];
582
+ if (parent == NULL) {
583
+ break;
584
+ }
585
+ AT_PRINTF("%s", parent->name);
586
+ if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
587
+ AT_PRINTF(", ");
588
+ }
589
  }
590
+ AT_PRINTF("\n");
591
  }
 
 
 
592
 
593
+ // update parents
594
+ // update immediately if there is no parse_seq
595
+ // update only at barriers if there is parse_seq
596
+ if ((parse_seq_len == 0) || parse_seq[ind] == -1) {
597
+ int update_start = parse_seq_len ? last_barrier_pos : ind;
598
+ int update_end = parse_seq_len ? ind : ind + 1;
599
+ for (int i = update_start; i < update_end; i++) {
600
+ int node_i = parse_seq_len ? parse_seq[i] : i;
601
+ struct ggml_tensor * node = gf->nodes[node_i];
602
 
 
603
  for (int j = 0; j < GGML_MAX_SRC; j++) {
604
  struct ggml_tensor * parent = node->src[j];
605
  if (parent == NULL) {
606
  break;
607
  }
608
+ struct hash_node * p_hn = hash_get(galloc, parent);
609
+ p_hn->n_children -= 1;
610
 
611
+ //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
 
612
 
613
+ if (p_hn->n_children == 0 && p_hn->n_views == 0) {
614
+ if (ggml_is_view(parent)) {
615
+ struct ggml_tensor * view_src = parent->view_src;
616
+ struct hash_node * view_src_hn = hash_get(galloc, view_src);
617
+ view_src_hn->n_views -= 1;
618
+ AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
619
+ if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0) {
620
+ free_node(galloc, view_src);
 
621
  }
622
  }
623
+ else {
624
+ free_node(galloc, parent);
625
+ }
626
  }
627
  }
 
 
 
 
628
  }
629
+ AT_PRINTF("\n");
630
+ if (parse_seq_len) {
631
+ last_barrier_pos = ind + 1;
 
 
 
 
632
  }
633
  }
634
  }
635
+ }
636
 
637
+ size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, struct ggml_cgraph * graph) {
638
+ size_t hash_size = graph->visited_hash_table.size;
639
+
640
+ // check if the hash table is initialized and large enough
641
+ if (galloc->hash_set.size < hash_size) {
642
+ if (galloc->hash_set.keys != NULL) {
643
+ free(galloc->hash_set.keys);
644
+ }
645
+ if (galloc->hash_values != NULL) {
646
+ free(galloc->hash_values);
647
+ }
648
+ galloc->hash_set.keys = malloc(sizeof(struct ggml_tensor *) * hash_size);
649
+ galloc->hash_set.size = hash_size;
650
+ galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size);
651
+ }
652
+
653
+ // reset hash table
654
+ memset(galloc->hash_set.keys, 0, sizeof(struct ggml_tensor *) * hash_size);
655
+ memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size);
656
+
657
+ galloc->talloc = talloc;
658
+ ggml_tallocr_alloc_graph_impl(galloc, graph);
659
+ galloc->talloc = NULL;
660
+
661
+ size_t max_size = ggml_tallocr_max_size(talloc);
662
+
663
+ return max_size;
664
+ }
665
+
666
+ void ggml_gallocr_alloc_graph_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, struct ggml_hash_set hash_set, ggml_tallocr_t * hash_node_alloct) {
667
+ const size_t hash_size = hash_set.size;
668
+
669
+ GGML_ASSERT(hash_size >= (size_t)(graph->n_nodes + graph->n_leafs));
670
+
671
+ galloc->talloc = NULL;
672
+
673
+ // alloc hash_values if needed
674
+ if (galloc->hash_values == NULL || galloc->hash_values_size < hash_size) {
675
+ free(galloc->hash_values);
676
+ galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size);
677
+ galloc->hash_values_size = hash_size;
678
+ }
679
+
680
+ // free hash_set.keys if needed
681
+ if (galloc->hash_set.keys != NULL) {
682
+ free(galloc->hash_set.keys);
683
+ }
684
+ galloc->hash_set = hash_set;
685
+
686
+ // reset hash values
687
+ memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size);
688
+
689
+ galloc->hash_allocs = hash_node_alloct;
690
+
691
+ ggml_tallocr_alloc_graph_impl(galloc, graph);
692
+
693
+ // remove unowned resources
694
+ galloc->hash_set.keys = NULL;
695
+ galloc->hash_allocs = NULL;
696
+ }
697
+
698
+ // legacy API wrapper
699
+
700
+ struct ggml_allocr {
701
+ ggml_tallocr_t talloc;
702
+ ggml_gallocr_t galloc;
703
+ };
704
+
705
+ static ggml_allocr_t ggml_allocr_new_impl(ggml_tallocr_t talloc) {
706
+ ggml_allocr_t alloc = (ggml_allocr_t)malloc(sizeof(struct ggml_allocr));
707
+ *alloc = (struct ggml_allocr) {
708
+ /*.talloc = */ talloc,
709
+ /*.galloc = */ ggml_gallocr_new(),
710
+ };
711
+ return alloc;
712
+ }
713
+
714
+ ggml_allocr_t ggml_allocr_new(void * data, size_t size, size_t alignment) {
715
+ return ggml_allocr_new_impl(ggml_tallocr_new(data, size, alignment));
716
+ }
717
+
718
+ ggml_allocr_t ggml_allocr_new_measure(size_t alignment) {
719
+ return ggml_allocr_new_impl(ggml_tallocr_new_measure(alignment));
720
+ }
721
+
722
+ ggml_allocr_t ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer) {
723
+ return ggml_allocr_new_impl(ggml_tallocr_new_from_buffer(buffer));
724
+ }
725
+
726
+ ggml_allocr_t ggml_allocr_new_from_backend(struct ggml_backend * backend, size_t size) {
727
+ return ggml_allocr_new_impl(ggml_tallocr_new_from_backend(backend, size));
728
+ }
729
+
730
+ ggml_allocr_t ggml_allocr_new_measure_from_backend(struct ggml_backend * backend) {
731
+ return ggml_allocr_new_impl(ggml_tallocr_new_measure_from_backend(backend));
732
+ }
733
+
734
+ struct ggml_backend_buffer * ggml_allocr_get_buffer(ggml_allocr_t alloc) {
735
+ return ggml_tallocr_get_buffer(alloc->talloc);
736
+ }
737
+
738
+ void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n) {
739
+ ggml_gallocr_set_parse_seq(alloc->galloc, list, n);
740
+ }
741
+
742
+ void ggml_allocr_free(ggml_allocr_t alloc) {
743
+ ggml_gallocr_free(alloc->galloc);
744
+ ggml_tallocr_free(alloc->talloc);
745
+ free(alloc);
746
+ }
747
+
748
+ bool ggml_allocr_is_measure(ggml_allocr_t alloc) {
749
+ return ggml_tallocr_is_measure(alloc->talloc);
750
+ }
751
+
752
+ void ggml_allocr_reset(ggml_allocr_t alloc) {
753
+ ggml_tallocr_reset(alloc->talloc);
754
+ }
755
+
756
+ void ggml_allocr_alloc(ggml_allocr_t alloc, struct ggml_tensor * tensor) {
757
+ ggml_tallocr_alloc(alloc->talloc, tensor);
758
+ }
759
+
760
+ size_t ggml_allocr_max_size(ggml_allocr_t alloc) {
761
+ return ggml_tallocr_max_size(alloc->talloc);
762
  }
763
 
764
+ size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) {
765
+ return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
766
  }
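The wrapper above keeps the pre-backend workflow working unchanged: measure first, then allocate for real. A minimal sketch of that workflow, assuming a hypothetical build_graph(ctx) helper that records the same graph into a no_alloc ggml context each time it is called (alignment of 32 is an arbitrary example value; in practice some headroom is usually added to the measured size):

// legacy two-pass allocation (CPU only)
ggml_allocr_t measure = ggml_allocr_new_measure(32);               // dummy allocator, only tracks sizes
size_t mem_size = ggml_allocr_alloc_graph(measure, build_graph(ctx));
ggml_allocr_free(measure);

void * buf = malloc(mem_size);
ggml_allocr_t alloc = ggml_allocr_new(buf, mem_size, 32);          // real allocator over the buffer
ggml_allocr_alloc_graph(alloc, build_graph(ctx));                  // tensor->data now points into buf
// ... compute the graph ...
ggml_allocr_free(alloc);
free(buf);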
ggml-alloc.h CHANGED
@@ -6,20 +6,79 @@
6
  extern "C" {
7
  #endif
8
 
 
 
9
 
10
- GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment);
11
- GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);
12
 
13
  // tell the allocator to parse nodes following the order described in the list
14
  // you should call this if your graph is optimized to execute out-of-order
15
- GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n);
16
 
17
- GGML_API void ggml_allocr_free(struct ggml_allocr * alloc);
18
- GGML_API bool ggml_allocr_is_measure(struct ggml_allocr * alloc);
19
- GGML_API void ggml_allocr_reset(struct ggml_allocr * alloc);
20
- GGML_API void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor);
21
- GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph);
22
 
 
 
 
 
 
 
23
 
24
  #ifdef __cplusplus
25
  }
 
6
  extern "C" {
7
  #endif
8
 
9
+ struct ggml_backend;
10
+ struct ggml_backend_buffer;
11
 
12
+ //
13
+ // Legacy API
14
+ //
15
+
16
+ typedef struct ggml_allocr * ggml_allocr_t;
17
+
18
+ // initialize allocator for use with CPU backend only
19
+ GGML_API ggml_allocr_t ggml_allocr_new(void * data, size_t size, size_t alignment);
20
+ GGML_API ggml_allocr_t ggml_allocr_new_measure(size_t alignment);
21
+
22
+ // initialize allocator for use with ggml-backend
23
+ GGML_API ggml_allocr_t ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer);
24
+ GGML_API ggml_allocr_t ggml_allocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer
25
+ GGML_API ggml_allocr_t ggml_allocr_new_measure_from_backend(struct ggml_backend * backend);
26
+
27
+ GGML_API struct ggml_backend_buffer * ggml_allocr_get_buffer(ggml_allocr_t alloc);
28
 
29
  // tell the allocator to parse nodes following the order described in the list
30
  // you should call this if your graph is optimized to execute out-of-order
31
+ GGML_API void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n);
32
+
33
+ GGML_API void ggml_allocr_free (ggml_allocr_t alloc);
34
+ GGML_API bool ggml_allocr_is_measure (ggml_allocr_t alloc);
35
+ GGML_API void ggml_allocr_reset (ggml_allocr_t alloc);
36
+ GGML_API void ggml_allocr_alloc (ggml_allocr_t alloc, struct ggml_tensor * tensor);
37
+ GGML_API size_t ggml_allocr_max_size (ggml_allocr_t alloc);
38
+
39
+ GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph);
40
+
41
+ //
42
+ // ggml-backend v2 API
43
+ //
44
+
45
+ // Separate tensor and graph allocator objects
46
+ // This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
47
+ // The original API is kept as a wrapper around the new API
48
+
49
+ // Tensor allocator
50
+ typedef struct ggml_tallocr * ggml_tallocr_t;
51
+
52
+ GGML_API ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment);
53
+ GGML_API ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment);
54
+ GGML_API ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer);
55
+ GGML_API ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer
56
+ GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend);
57
+
58
+ GGML_API struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t talloc);
59
+
60
+ GGML_API void ggml_tallocr_free (ggml_tallocr_t talloc);
61
+ GGML_API bool ggml_tallocr_is_measure (ggml_tallocr_t talloc);
62
+ GGML_API void ggml_tallocr_reset (ggml_tallocr_t talloc);
63
+ GGML_API void ggml_tallocr_alloc (ggml_tallocr_t talloc, struct ggml_tensor * tensor);
64
+ GGML_API size_t ggml_tallocr_max_size (ggml_tallocr_t talloc);
65
+
66
+
67
+ // Graph allocator
68
+ typedef struct ggml_gallocr * ggml_gallocr_t;
69
+
70
+ GGML_API ggml_gallocr_t ggml_gallocr_new(void);
71
+ GGML_API void ggml_gallocr_free(ggml_gallocr_t galloc);
72
 
73
+ GGML_API void ggml_gallocr_set_parse_seq(ggml_gallocr_t galloc, const int * list, int n);
74
+ GGML_API size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, struct ggml_cgraph * graph);
 
 
 
75
 
76
+ // Allocate tensors from the allocators given by the hash table
77
+ GGML_API void ggml_gallocr_alloc_graph_n(
78
+ ggml_gallocr_t galloc,
79
+ struct ggml_cgraph * graph,
80
+ struct ggml_hash_set hash_set,
81
+ ggml_tallocr_t * hash_node_talloc);
82
 
83
  #ifdef __cplusplus
84
  }
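The same measure-then-allocate flow with the split v2 API declared above, as a hedged sketch: backend is assumed to come from somewhere like ggml_backend_cpu_init(), and build_graph(ctx) is the same hypothetical graph builder as before.

ggml_gallocr_t galloc = ggml_gallocr_new();

// measure pass: a dummy tensor allocator that only tracks the high-water mark
ggml_tallocr_t talloc = ggml_tallocr_new_measure_from_backend(backend);
size_t mem_size = ggml_gallocr_alloc_graph(galloc, talloc, build_graph(ctx));
ggml_tallocr_free(talloc);

// real pass: a backend-owned buffer of the measured size
talloc = ggml_tallocr_new_from_backend(backend, mem_size);
ggml_gallocr_alloc_graph(galloc, talloc, build_graph(ctx));
// ... compute ...
ggml_gallocr_free(galloc);
ggml_tallocr_free(talloc);   // also frees the owned buffer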
ggml-backend-impl.h ADDED
@@ -0,0 +1,87 @@
1
+ #pragma once
2
+
3
+ // ggml-backend internal header
4
+
5
+ #include "ggml-backend.h"
6
+
7
+ #ifdef __cplusplus
8
+ extern "C" {
9
+ #endif
10
+
11
+ //
12
+ // Backend buffer
13
+ //
14
+
15
+ typedef void * ggml_backend_buffer_context_t;
16
+
17
+ struct ggml_backend_buffer_i {
18
+ void (*free_buffer) (ggml_backend_buffer_t buffer);
19
+ void * (*get_base) (ggml_backend_buffer_t buffer); // get base pointer
20
+ size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback
21
+ void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback
22
+ void (*free_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback
23
+ };
24
+
25
+ struct ggml_backend_buffer {
26
+ struct ggml_backend_buffer_i iface;
27
+
28
+ ggml_backend_t backend;
29
+ ggml_backend_buffer_context_t context;
30
+
31
+ size_t size;
32
+ };
33
+
34
+ GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
35
+ struct ggml_backend * backend,
36
+ struct ggml_backend_buffer_i iface,
37
+ ggml_backend_buffer_context_t context,
38
+ size_t size);
39
+
40
+ //
41
+ // Backend
42
+ //
43
+
44
+ typedef void * ggml_backend_context_t;
45
+
46
+ struct ggml_backend_i {
47
+ const char * (*get_name)(ggml_backend_t backend);
48
+
49
+ void (*free)(ggml_backend_t backend);
50
+
51
+ // buffer allocation
52
+ ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size);
53
+
54
+ // get buffer alignment
55
+ size_t (*get_alignment)(ggml_backend_t backend);
56
+
57
+ // tensor data access
58
+ // these functions can be asynchronous; helper functions are provided for synchronous access that automatically call synchronize
59
+ void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
60
+ void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
61
+ void (*synchronize) (ggml_backend_t backend);
62
+
63
+ // (optional) copy tensor between different backends, allow for single-copy tranfers
64
+ void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
65
+ void (*cpy_tensor_to) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
66
+
67
+ // compute graph with a plan
68
+ ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
69
+ void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
70
+ void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
71
+
72
+ // compute graph without a plan
73
+ void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
74
+
75
+ // check if the backend supports an operation
76
+ bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
77
+ };
78
+
79
+ struct ggml_backend {
80
+ struct ggml_backend_i iface;
81
+
82
+ ggml_backend_context_t context;
83
+ };
84
+
85
+ #ifdef __cplusplus
86
+ }
87
+ #endif
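These two interfaces are the whole contract a new backend has to implement; the CPU backend in ggml-backend.c below is the reference implementation. Purely as an illustration (not part of this commit), an alloc_buffer callback for a custom backend would wrap its device allocation with ggml_backend_buffer_init; my_device_malloc and my_buffer_i are hypothetical names:

static ggml_backend_buffer_t my_backend_alloc_buffer(ggml_backend_t backend, size_t size) {
    void * data = my_device_malloc(size);                                // device-specific allocation (hypothetical)
    return ggml_backend_buffer_init(backend, my_buffer_i, data, size);   // my_buffer_i: a struct ggml_backend_buffer_i
}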
ggml-backend.c ADDED
@@ -0,0 +1,950 @@
1
+ #include "ggml-backend-impl.h"
2
+ #include "ggml-alloc.h"
3
+ #include "ggml-impl.h"
4
+
5
+ #include <assert.h>
6
+ #include <limits.h>
7
+ #include <stdarg.h>
8
+ #include <stdio.h>
9
+ #include <stdlib.h>
10
+ #include <string.h>
11
+
12
+ #define UNUSED GGML_UNUSED
13
+
14
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
15
+
16
+ // backend buffer
17
+
18
+ ggml_backend_buffer_t ggml_backend_buffer_init(
19
+ struct ggml_backend * backend,
20
+ struct ggml_backend_buffer_i iface,
21
+ ggml_backend_buffer_context_t context,
22
+ size_t size) {
23
+ ggml_backend_buffer_t buffer = malloc(sizeof(struct ggml_backend_buffer));
24
+
25
+ GGML_ASSERT(iface.get_base != NULL);
26
+
27
+ (*buffer) = (struct ggml_backend_buffer) {
28
+ /* .interface = */ iface,
29
+ /* .backend = */ backend,
30
+ /* .context = */ context,
31
+ /* .size = */ size,
32
+ };
33
+
34
+ return buffer;
35
+ }
36
+
37
+ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
38
+ if (buffer == NULL) {
39
+ return;
40
+ }
41
+
42
+ if (buffer->iface.free_buffer != NULL) {
43
+ buffer->iface.free_buffer(buffer);
44
+ }
45
+ free(buffer);
46
+ }
47
+
48
+ size_t ggml_backend_buffer_get_alignment(ggml_backend_buffer_t buffer) {
49
+ return ggml_backend_get_alignment(buffer->backend);
50
+ }
51
+
52
+ size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
53
+ return buffer->size;
54
+ }
55
+
56
+ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
57
+ void * base = buffer->iface.get_base(buffer);
58
+
59
+ GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL");
60
+
61
+ return base;
62
+ }
63
+
64
+ size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
65
+ // get_alloc_size is optional, defaults to ggml_nbytes
66
+ if (buffer->iface.get_alloc_size) {
67
+ return buffer->iface.get_alloc_size(buffer, tensor);
68
+ }
69
+ return ggml_nbytes(tensor);
70
+ }
71
+
72
+ void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
73
+ // init_tensor is optional
74
+ if (buffer->iface.init_tensor) {
75
+ buffer->iface.init_tensor(buffer, tensor);
76
+ }
77
+ }
78
+
79
+ void ggml_backend_buffer_free_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
80
+ // free_tensor is optional
81
+ if (buffer->iface.free_tensor) {
82
+ buffer->iface.free_tensor(buffer, tensor);
83
+ }
84
+ }
85
+
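The helpers above are the only way ggml-alloc touches a buffer, and the optional callbacks (get_alloc_size, init_tensor, free_tensor) fall back to sensible defaults when a backend leaves them NULL. A small usage sketch, assuming a backend created elsewhere:

ggml_backend_buffer_t buf = ggml_backend_alloc_buffer(backend, 16*1024*1024);
printf("base %p, size %zu\n", ggml_backend_buffer_get_base(buf), ggml_backend_buffer_get_size(buf));
ggml_backend_buffer_free(buf);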
86
+ // backend
87
+
88
+ ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor) {
89
+ return tensor->buffer ? tensor->buffer->backend : NULL;
90
+ }
91
+
92
+ const char * ggml_backend_name(ggml_backend_t backend) {
93
+ if (backend == NULL) {
94
+ return "NULL";
95
+ }
96
+ return backend->iface.get_name(backend);
97
+ }
98
+
99
+ void ggml_backend_free(ggml_backend_t backend) {
100
+ if (backend == NULL) {
101
+ return;
102
+ }
103
+
104
+ backend->iface.free(backend);
105
+ }
106
+
107
+ ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) {
108
+ return backend->iface.alloc_buffer(backend, size);
109
+ }
110
+
111
+ size_t ggml_backend_get_alignment(ggml_backend_t backend) {
112
+ return backend->iface.get_alignment(backend);
113
+ }
114
+
115
+ void ggml_backend_tensor_set_async(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
116
+ ggml_get_backend(tensor)->iface.set_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size);
117
+ }
118
+
119
+ void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
120
+ ggml_get_backend(tensor)->iface.get_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size);
121
+ }
122
+
123
+ void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
124
+ ggml_backend_t backend = ggml_get_backend(tensor);
125
+
126
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
127
+ GGML_ASSERT(backend != NULL && "tensor backend not set");
128
+
129
+ backend->iface.set_tensor_async(backend, tensor, data, offset, size);
130
+ backend->iface.synchronize(backend);
131
+ }
132
+
133
+ void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
134
+ ggml_backend_t backend = ggml_get_backend(tensor);
135
+
136
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
137
+ GGML_ASSERT(backend != NULL && "tensor backend not set");
138
+
139
+ backend->iface.get_tensor_async(backend, tensor, data, offset, size);
140
+ backend->iface.synchronize(backend);
141
+ }
142
+
143
+ void ggml_backend_synchronize(ggml_backend_t backend) {
144
+ backend->iface.synchronize(backend);
145
+ }
146
+
147
+ ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
148
+ return backend->iface.graph_plan_create(backend, cgraph);
149
+ }
150
+
151
+ void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
152
+ backend->iface.graph_plan_free(backend, plan);
153
+ }
154
+
155
+ void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
156
+ backend->iface.graph_plan_compute(backend, plan);
157
+ }
158
+
159
+ void ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
160
+ backend->iface.graph_compute(backend, cgraph);
161
+ }
162
+
163
+ bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
164
+ return backend->iface.supports_op(backend, op);
165
+ }
166
+
167
+ // backend copy
168
+
169
+ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
170
+ if (a->type != b->type) {
171
+ return false;
172
+ }
173
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
174
+ if (a->ne[i] != b->ne[i]) {
175
+ return false;
176
+ }
177
+ if (a->nb[i] != b->nb[i]) {
178
+ return false;
179
+ }
180
+ }
181
+ return true;
182
+ }
183
+
184
+ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) {
185
+ //printf("src: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", src->name, (int)src->ne[0], (int)src->ne[1], (int)src->ne[2], (int)src->ne[3], (int)src->nb[0], (int)src->nb[1], (int)src->nb[2], (int)src->nb[3]);
186
+ //printf("dst: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", dst->name, (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], (int)dst->nb[0], (int)dst->nb[1], (int)dst->nb[2], (int)dst->nb[3]);
187
+ GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
188
+
189
+ // fprintf(stderr, "cpy tensor %s from %s to %s (%lu bytes)\n", src->name, ggml_backend_name(src->backend), ggml_backend_name(dst->backend), ggml_nbytes(src));
190
+
191
+ if (src == dst) {
192
+ return;
193
+ }
194
+
195
+ // TODO: allow backends to support copy to/from same backend
196
+
197
+ if (ggml_get_backend(dst)->iface.cpy_tensor_from != NULL) {
198
+ ggml_get_backend(dst)->iface.cpy_tensor_from(ggml_get_backend(dst)->context, src, dst);
199
+ } else if (ggml_get_backend(src)->iface.cpy_tensor_to != NULL) {
200
+ ggml_get_backend(src)->iface.cpy_tensor_to(ggml_get_backend(src)->context, src, dst);
201
+ } else {
202
+ // shouldn't be hit when copying from/to CPU
203
+ #ifndef NDEBUG
204
+ fprintf(stderr, "ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to are implemented for backends %s and %s, falling back to get/set\n", ggml_backend_name(src->buffer->backend), ggml_backend_name(dst->buffer->backend));
205
+ #endif
206
+ size_t nbytes = ggml_nbytes(src);
207
+ void * data = malloc(nbytes);
208
+ ggml_backend_tensor_get(src, data, 0, nbytes);
209
+ ggml_backend_tensor_set(dst, data, 0, nbytes);
210
+ free(data);
211
+ }
212
+ }
213
+
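ggml_backend_tensor_set/get are the synchronous counterparts of the async callbacks: they forward to the backend and then synchronize. A hedged sketch, assuming t is a tensor that has already been allocated in some backend buffer:

float in[16] = {0}, out[16];
ggml_backend_tensor_set(t, in,  0, sizeof(in));    // upload, then synchronize
ggml_backend_tensor_get(t, out, 0, sizeof(out));   // download, then synchronize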
214
+ // backend CPU
215
+
216
+ struct ggml_backend_cpu_context {
217
+ int n_threads;
218
+ void * work_data;
219
+ size_t work_size;
220
+ };
221
+
222
+ static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
223
+ return "CPU";
224
+
225
+ UNUSED(backend);
226
+ }
227
+
228
+ static void ggml_backend_cpu_free(ggml_backend_t backend) {
229
+ struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
230
+ free(cpu_ctx->work_data);
231
+ free(cpu_ctx);
232
+ free(backend);
233
+ }
234
+
235
+ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
236
+ return (void *)buffer->context;
237
+ }
238
+
239
+ static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
240
+ free(buffer->context);
241
+ UNUSED(buffer);
242
+ }
243
+
244
+ static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
245
+ /* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer,
246
+ /* .get_base = */ ggml_backend_cpu_buffer_get_base,
247
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
248
+ /* .init_tensor = */ NULL, // no initialization required
249
+ /* .free_tensor = */ NULL, // no cleanup required
250
+ };
251
+
252
+ // for buffers from ptr, free is not called
253
+ static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
254
+ /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
255
+ /* .get_base = */ ggml_backend_cpu_buffer_get_base,
256
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
257
+ /* .init_tensor = */ NULL,
258
+ /* .free_tensor = */ NULL,
259
+ };
260
+
261
+ static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512
262
+
263
+ static ggml_backend_buffer_t ggml_backend_cpu_alloc_buffer(ggml_backend_t backend, size_t size) {
264
+ size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned
265
+ void * data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC?
266
+
267
+ GGML_ASSERT(data != NULL && "failed to allocate buffer");
268
+
269
+ return ggml_backend_buffer_init(backend, cpu_backend_buffer_i, data, size);
270
+ }
271
+
272
+ static size_t ggml_backend_cpu_get_alignment(ggml_backend_t backend) {
273
+ return TENSOR_ALIGNMENT;
274
+ UNUSED(backend);
275
+ }
276
+
277
+ static void ggml_backend_cpu_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
278
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
279
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
280
+
281
+ memcpy((char *)tensor->data + offset, data, size);
282
+
283
+ UNUSED(backend);
284
+ }
285
+
286
+ static void ggml_backend_cpu_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
287
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
288
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
289
+
290
+ memcpy(data, (const char *)tensor->data + offset, size);
291
+
292
+ UNUSED(backend);
293
+ }
294
+
295
+ static void ggml_backend_cpu_synchronize(ggml_backend_t backend) {
296
+ UNUSED(backend);
297
+ }
298
+
299
+ static void ggml_backend_cpu_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
300
+ ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
301
+
302
+ UNUSED(backend);
303
+ }
304
+
305
+ static void ggml_backend_cpu_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
306
+ ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
307
+
308
+ UNUSED(backend);
309
+ }
310
+
311
+ struct ggml_backend_plan_cpu {
312
+ struct ggml_cplan cplan;
313
+ struct ggml_cgraph cgraph;
314
+ };
315
+
316
+ static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
317
+ struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
318
+
319
+ struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu));
320
+
321
+ cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads);
322
+ cpu_plan->cgraph = *cgraph;
323
+
324
+ if (cpu_plan->cplan.work_size > 0) {
325
+ cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size);
326
+ }
327
+
328
+ return cpu_plan;
329
+ }
330
+
331
+ static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
332
+ struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
333
+
334
+ free(cpu_plan->cplan.work_data);
335
+ free(cpu_plan);
336
+
337
+ UNUSED(backend);
338
+ }
339
+
340
+ static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
341
+ struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
342
+
343
+ ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
344
+
345
+ UNUSED(backend);
346
+ }
347
+
348
+ static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
349
+ struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
350
+
351
+ struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads);
352
+
353
+ if (cpu_ctx->work_size < cplan.work_size) {
354
+ // TODO: may be faster to free and use malloc to avoid the copy
355
+ cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size);
356
+ cpu_ctx->work_size = cplan.work_size;
357
+ }
358
+
359
+ cplan.work_data = cpu_ctx->work_data;
360
+
361
+ ggml_graph_compute(cgraph, &cplan);
362
+ }
363
+
364
+ static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
365
+ return true;
366
+ UNUSED(backend);
367
+ UNUSED(op);
368
+ }
369
+
370
+ static struct ggml_backend_i cpu_backend_i = {
371
+ /* .get_name = */ ggml_backend_cpu_name,
372
+ /* .free = */ ggml_backend_cpu_free,
373
+ /* .alloc_buffer = */ ggml_backend_cpu_alloc_buffer,
374
+ /* .get_alignment = */ ggml_backend_cpu_get_alignment,
375
+ /* .set_tensor_async = */ ggml_backend_cpu_set_tensor_async,
376
+ /* .get_tensor_async = */ ggml_backend_cpu_get_tensor_async,
377
+ /* .synchronize = */ ggml_backend_cpu_synchronize,
378
+ /* .cpy_tensor_from = */ ggml_backend_cpu_cpy_tensor_from,
379
+ /* .cpy_tensor_to = */ ggml_backend_cpu_cpy_tensor_to,
380
+ /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create,
381
+ /* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free,
382
+ /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute,
383
+ /* .graph_compute = */ ggml_backend_cpu_graph_compute,
384
+ /* .supports_op = */ ggml_backend_cpu_supports_op,
385
+ };
386
+
387
+ ggml_backend_t ggml_backend_cpu_init(void) {
388
+ struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context));
389
+
390
+ ctx->n_threads = GGML_DEFAULT_N_THREADS;
391
+ ctx->work_data = NULL;
392
+ ctx->work_size = 0;
393
+
394
+ ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend));
395
+
396
+ *cpu_backend = (struct ggml_backend) {
397
+ /* .interface = */ cpu_backend_i,
398
+ /* .context = */ ctx
399
+ };
400
+ return cpu_backend;
401
+ }
402
+
403
+ bool ggml_backend_is_cpu(ggml_backend_t backend) {
404
+ return backend->iface.get_name == ggml_backend_cpu_name;
405
+ }
406
+
407
+ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
408
+ GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
409
+
410
+ struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
411
+ ctx->n_threads = n_threads;
412
+ }
413
+
414
+ ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size) {
415
+ return ggml_backend_buffer_init(backend_cpu, cpu_backend_buffer_i_from_ptr, ptr, size);
416
+ }
417
+
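Putting the CPU backend together with the allocator API from ggml-alloc.h, a minimal end-to-end sketch (build_graph(ctx) is the same hypothetical graph builder as before):

ggml_backend_t cpu = ggml_backend_cpu_init();
ggml_backend_cpu_set_n_threads(cpu, 4);

ggml_allocr_t alloc = ggml_allocr_new_from_backend(cpu, 64*1024*1024);   // backend-owned buffer
struct ggml_cgraph * gf = build_graph(ctx);
ggml_allocr_alloc_graph(alloc, gf);

ggml_backend_graph_compute(cpu, gf);

ggml_allocr_free(alloc);
ggml_backend_free(cpu);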
418
+ // scheduler
419
+
420
+ #define GGML_MAX_BACKENDS 4
421
+ #define GGML_MAX_SPLITS 256
422
+ #define GGML_MAX_SPLIT_INPUTS 16
423
+
424
+ struct ggml_backend_sched_split {
425
+ ggml_tallocr_t tallocr;
426
+ int i_start;
427
+ int i_end;
428
+ struct ggml_tensor * inputs[GGML_MAX_SPLIT_INPUTS];
429
+ int n_inputs;
430
+ struct ggml_cgraph * graph;
431
+ };
432
+
433
+ struct ggml_backend_sched {
434
+ int n_backends;
435
+ ggml_backend_t backends[GGML_MAX_BACKENDS];
436
+ ggml_tallocr_t tallocs[GGML_MAX_BACKENDS];
437
+
438
+ ggml_gallocr_t galloc;
439
+
440
+ struct ggml_hash_set hash_set;
441
+ ggml_tallocr_t * node_talloc; // [hash_set.size]
442
+ struct ggml_tensor * (* node_copies)[GGML_MAX_BACKENDS]; // [hash_set.size][GGML_MAX_BACKENDS]
443
+
444
+ struct ggml_cgraph * graph;
445
+ struct ggml_backend_sched_split splits[GGML_MAX_SPLITS];
446
+ int n_splits;
447
+
448
+ struct ggml_context * ctx;
449
+
450
+ // align context_buffer to GGML_MEM_ALIGN
451
+ #ifdef _MSC_VER
452
+ __declspec(align(GGML_MEM_ALIGN))
453
+ #else
454
+ __attribute__((aligned(GGML_MEM_ALIGN)))
455
+ #endif
456
+ char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + GGML_MAX_SPLITS*sizeof(struct ggml_cgraph)];
457
+ };
458
+
459
+ #define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node)
460
+ #define node_allocr(node) sched->node_talloc[hash_id(node)]
461
+
462
+ static bool ggml_is_view_op(enum ggml_op op) {
463
+ return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
464
+ }
465
+
466
+ // returns the priority of the backend, lower is better
467
+ static int sched_backend_prio(ggml_backend_sched_t sched, ggml_backend_t backend) {
468
+ for (int i = 0; i < sched->n_backends; i++) {
469
+ if (sched->backends[i] == backend) {
470
+ return i;
471
+ }
472
+ }
473
+ return INT_MAX;
474
+ }
475
+
476
+ static int sched_allocr_prio(ggml_backend_sched_t sched, ggml_tallocr_t allocr) {
477
+ for (int i = 0; i < sched->n_backends; i++) {
478
+ if (sched->tallocs[i] == allocr) {
479
+ return i;
480
+ }
481
+ }
482
+ return INT_MAX;
483
+ }
484
+
485
+ // returns the backend that should be used for the node based on the current locations
486
+ char causes[GGML_DEFAULT_GRAPH_SIZE*4 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
487
+ static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * node) {
488
+ // if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there
489
+ // ie. kv cache updates
490
+ // note that this doesn't allow fallback to CPU. need to add output tensors to the splits to copy the data back to the original backend.
491
+ // dst
492
+ ggml_backend_t cur_backend = ggml_get_backend(node);
493
+ if (cur_backend != NULL) {
494
+ sprintf(causes[hash_id(node)], "1.dst");
495
+ return cur_backend;
496
+ }
497
+
498
+ // view_src
499
+ if (node->view_src != NULL && ggml_get_backend(node->view_src) != NULL) {
500
+ sprintf(causes[hash_id(node)], "1.vsrc");
501
+ return ggml_get_backend(node->view_src);
502
+ }
503
+
504
+ // src
505
+ int cur_prio = INT_MAX;
506
+ size_t cur_size = 0;
507
+
508
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
509
+ const struct ggml_tensor * src = node->src[i];
510
+ if (src == NULL) {
511
+ break;
512
+ }
513
+ ggml_backend_t src_backend = ggml_get_backend(src);
514
+ if (src_backend != NULL) {
515
+ int src_prio = sched_backend_prio(sched, src_backend);
516
+ size_t src_size = ggml_nbytes(src);
517
+ if (src_prio < cur_prio && src_size >= cur_size) {
518
+ cur_prio = src_prio;
519
+ cur_size = src_size;
520
+ cur_backend = src_backend;
521
+ sprintf(causes[hash_id(node)], "1.src%d", i);
522
+ }
523
+ }
524
+ }
525
+ return cur_backend;
526
+ }
527
+
528
+ static char * fmt_size(size_t size) {
529
+ static char buffer[128];
530
+ if (size >= 1024*1024) {
531
+ sprintf(buffer, "%zuM", size/1024/1024);
532
+ } else {
533
+ sprintf(buffer, "%zuK", size/1024);
534
+ }
535
+ return buffer;
536
+ }
537
+
538
+ static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
539
+ int cur_split = 0;
540
+ for (int i = 0; i < graph->n_nodes; i++) {
541
+ if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
542
+ ggml_backend_t split_backend = ggml_tallocr_get_buffer(sched->splits[cur_split].tallocr)->backend;
543
+ fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend), sched->splits[cur_split].n_inputs);
544
+ for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
545
+ fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
546
+ }
547
+ fprintf(stderr, "\n");
548
+ cur_split++;
549
+ }
550
+ struct ggml_tensor * node = graph->nodes[i];
551
+ if (ggml_is_view_op(node->op)) {
552
+ continue;
553
+ }
554
+ ggml_tallocr_t node_allocr = node_allocr(node);
555
+ ggml_backend_t node_backend = node_allocr ? ggml_tallocr_get_buffer(node_allocr)->backend : NULL;
556
+ fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, ggml_op_name(node->op), node->name, fmt_size(ggml_nbytes(node)), node_allocr ? ggml_backend_name(node_backend) : "NULL", causes[hash_id(node)]);
557
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
558
+ struct ggml_tensor * src = node->src[j];
559
+ if (src == NULL) {
560
+ break;
561
+ }
562
+ ggml_tallocr_t src_allocr = node_allocr(src);
563
+ ggml_backend_t src_backend = src_allocr ? ggml_tallocr_get_buffer(src_allocr)->backend : NULL;
564
+ fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name, fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", causes[hash_id(src)]);
565
+ }
566
+ fprintf(stderr, "\n");
567
+ }
568
+ }
569
+
570
+ // creates a copy of the tensor with the same memory layout
571
+ static struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, const struct ggml_tensor * tensor) {
572
+ struct ggml_tensor * dup = ggml_dup_tensor(ctx, tensor);
573
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
574
+ dup->nb[i] = tensor->nb[i];
575
+ }
576
+ return dup;
577
+ }
578
+
579
+ // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
580
+ // TODO: merge passes
581
+ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
582
+ // reset state
583
+ size_t hash_size = sched->hash_set.size;
584
+ memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size);
585
+ memset(sched->node_talloc, 0, sizeof(sched->node_talloc[0]) * hash_size);
586
+ memset(sched->node_copies, 0, sizeof(sched->node_copies[0]) * hash_size);
587
+ sched->n_splits = 0;
588
+
589
+ struct ggml_init_params params = {
590
+ /*.mem_size = */ sizeof(sched->context_buffer),
591
+ /*.mem_buffer = */ sched->context_buffer,
592
+ /*.no_alloc = */ true
593
+ };
594
+
595
+ if (sched->ctx != NULL) {
596
+ ggml_free(sched->ctx);
597
+ }
598
+
599
+ sched->ctx = ggml_init(params);
600
+
601
+ // pass 1: assign backends to ops with allocated inputs
602
+ for (int i = 0; i < graph->n_leafs; i++) {
603
+ struct ggml_tensor * leaf = graph->leafs[i];
604
+ if (node_allocr(leaf) != NULL) {
605
+ // do not overwrite user assignments
606
+ continue;
607
+ }
608
+ ggml_backend_t leaf_backend = ggml_get_backend(leaf);
609
+ if (leaf_backend == NULL && leaf->view_src != NULL) {
610
+ leaf_backend = ggml_get_backend(leaf->view_src);
611
+ }
612
+ if (leaf_backend != NULL) {
613
+ node_allocr(leaf) = ggml_backend_sched_get_tallocr(sched, leaf_backend);
614
+ }
615
+ }
616
+
617
+ for (int i = 0; i < graph->n_nodes; i++) {
618
+ struct ggml_tensor * node = graph->nodes[i];
619
+ if (node_allocr(node) != NULL) {
620
+ // do not overwrite user assignments
621
+ continue;
622
+ }
623
+ ggml_backend_t node_backend = sched_backend_from_cur(sched, node);
624
+ if (node_backend != NULL) {
625
+ node_allocr(node) = ggml_backend_sched_get_tallocr(sched, node_backend);
626
+ }
627
+ }
628
+ //printf("PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
629
+
630
+ // pass 2: assign backends to ops from current assignments
631
+ // TODO:
632
+ // - reuse sched_backend_from_cur
633
+ for (int i = 0; i < graph->n_nodes; i++) {
634
+ struct ggml_tensor * node = graph->nodes[i];
635
+ ggml_tallocr_t node_allocr = node_allocr(node);
636
+ if (node_allocr == NULL) {
637
+ int cur_prio = INT_MAX;
638
+ size_t cur_size = 0;
639
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
640
+ struct ggml_tensor * src = node->src[j];
641
+ if (src == NULL) {
642
+ break;
643
+ }
644
+ ggml_tallocr_t src_allocr = node_allocr(src);
645
+ if (src_allocr != NULL) {
646
+ int src_prio = sched_allocr_prio(sched, src_allocr);
647
+ size_t src_size = ggml_nbytes(src);
648
+ if (src_prio < cur_prio && src_size >= cur_size) {
649
+ cur_prio = src_prio;
650
+ cur_size = src_size;
651
+ node_allocr = src_allocr;
652
+ sprintf(causes[hash_id(node)], "2.src%d", j);
653
+ }
654
+ }
655
+ }
656
+ if (node_allocr != NULL) {
657
+ node_allocr(node) = node_allocr;
658
+ }
659
+ }
660
+ }
661
+ //printf("PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
662
+
663
+ // pass 3: assign backends to remaining src from dst (should only be leafs)
664
+ for (int i = 0; i < graph->n_nodes; i++) {
665
+ struct ggml_tensor * node = graph->nodes[i];
666
+ ggml_tallocr_t node_allocr = node_allocr(node);
667
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
668
+ struct ggml_tensor * src = node->src[j];
669
+ if (src == NULL) {
670
+ break;
671
+ }
672
+ ggml_tallocr_t src_allocr = node_allocr(src);
673
+ if (src_allocr == NULL) {
674
+ node_allocr(src) = node_allocr;
675
+ }
676
+ }
677
+ }
678
+ //printf("PASS 3 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
679
+
680
+ // pass 4: split graph, find tensors that need to be copied
681
+ // TODO:
682
+ // - when switching from a less preferred backend to a more preferred backend, check if it is possible to move the switch to an earlier point for the same cost
683
+ // find first backend
684
+ int cur_split = 0;
685
+ for (int i = 0; i < graph->n_nodes; i++) {
686
+ struct ggml_tensor * node = graph->nodes[i];
687
+ if (node->view_src == NULL) {
688
+ sched->splits[0].tallocr = node_allocr(node);
689
+ break;
690
+ }
691
+ }
692
+ sched->splits[0].i_start = 0;
693
+ sched->splits[0].n_inputs = 0;
694
+ memset(sched->splits[0].inputs, 0, sizeof(sched->splits[0].inputs)); //HACK
695
+ ggml_tallocr_t cur_allocr = sched->splits[0].tallocr;
696
+ size_t cur_backend_id = sched_allocr_prio(sched, cur_allocr);
697
+ for (int i = 0; i < graph->n_nodes; i++) {
698
+ struct ggml_tensor * node = graph->nodes[i];
699
+
700
+ if (ggml_is_view_op(node->op)) {
701
+ continue;
702
+ }
703
+
704
+ ggml_tallocr_t node_allocr = node_allocr(node);
705
+
706
+ if (node_allocr != cur_allocr) {
707
+ sched->splits[cur_split].i_end = i;
708
+ cur_split++;
709
+ GGML_ASSERT(cur_split < GGML_MAX_SPLITS);
710
+ sched->splits[cur_split].tallocr = node_allocr;
711
+ sched->splits[cur_split].i_start = i;
712
+ sched->splits[cur_split].n_inputs = 0;
713
+ memset(sched->splits[cur_split].inputs, 0, sizeof(sched->splits[cur_split].inputs)); //HACK
714
+ cur_allocr = node_allocr;
715
+ cur_backend_id = sched_allocr_prio(sched, cur_allocr);
716
+ }
717
+
718
+ // find inputs that are not on the same backend
719
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
720
+ struct ggml_tensor * src = node->src[j];
721
+ if (src == NULL) {
722
+ break;
723
+ }
724
+ ggml_tallocr_t src_allocr = node_allocr(src);
725
+ if (src_allocr != node_allocr) {
726
+ int n_inputs = sched->splits[cur_split].n_inputs++;
727
+ GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS);
728
+ sched->splits[cur_split].inputs[n_inputs] = (struct ggml_tensor *)src;
729
+
730
+ // create copies
731
+ size_t id = hash_id(src);
732
+ if (sched->node_copies[id][cur_backend_id] == NULL) {
733
+ struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
734
+ sched->node_copies[id][cur_backend_id] = tensor_copy;
735
+ node_allocr(tensor_copy) = cur_allocr;
736
+ ggml_backend_t backend = ggml_tallocr_get_buffer(cur_allocr)->backend;
737
+ ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name);
738
+ }
739
+ node->src[j] = sched->node_copies[id][cur_backend_id];
740
+ }
741
+ }
742
+ }
743
+ sched->splits[cur_split].i_end = graph->n_nodes;
744
+ sched->n_splits = cur_split + 1;
745
+
746
+ //fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); fflush(stdout);
747
+
748
+ #if 1
749
+ // sanity check: all sources should have the same backend as the node
750
+ for (int i = 0; i < graph->n_nodes; i++) {
751
+ struct ggml_tensor * node = graph->nodes[i];
752
+ ggml_tallocr_t node_allocr = node_allocr(node);
753
+ if (node_allocr == NULL) {
754
+ fprintf(stderr, "!!!!!!! %s has no backend\n", node->name);
755
+ }
756
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
757
+ struct ggml_tensor * src = node->src[j];
758
+ if (src == NULL) {
759
+ break;
760
+ }
761
+ ggml_tallocr_t src_allocr = node_allocr(src);
762
+ if (src_allocr != node_allocr /* && src_backend != NULL */) { // ignore nulls for now
763
+ fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n",
764
+ node->name, node_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(node_allocr)->backend) : "NULL",
765
+ j, src->name, src_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(src_allocr)->backend) : "NULL");
766
+ }
767
+ }
768
+ }
769
+ #endif
770
+
771
+ // create copies of the graph for each split
772
+ // FIXME: avoid this copy, pass split inputs to ggml_gallocr_alloc_graph_n in some other way
773
+ struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_MAX_SPLIT_INPUTS, false);
774
+ for (int i = 0; i < sched->n_splits; i++) {
775
+ struct ggml_backend_sched_split * split = &sched->splits[i];
776
+ split->graph = ggml_graph_view(sched->ctx, graph, split->i_start, split->i_end);
777
+
778
+ // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
779
+ for (int j = 0; j < split->n_inputs; j++) {
780
+ struct ggml_tensor * input = split->inputs[j];
781
+ struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_allocr_prio(sched, split->tallocr)];
782
+ input_cpy->src[0] = input;
783
+ graph_copy->nodes[graph_copy->n_nodes++] = input_cpy;
784
+ }
785
+
786
+ for (int j = split->i_start; j < split->i_end; j++) {
787
+ graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j];
788
+ }
789
+ }
790
+ sched->graph = graph_copy;
791
+ }
792
+
793
+ static void sched_alloc_splits(ggml_backend_sched_t sched) {
794
+ ggml_gallocr_alloc_graph_n(
795
+ sched->galloc,
796
+ sched->graph,
797
+ sched->hash_set,
798
+ sched->node_talloc);
799
+ }
800
+
801
+ static void sched_compute_splits(ggml_backend_sched_t sched) {
802
+ uint64_t copy_us[GGML_MAX_BACKENDS] = {0};
803
+ uint64_t compute_us[GGML_MAX_BACKENDS] = {0};
804
+
805
+ struct ggml_backend_sched_split * splits = sched->splits;
806
+
807
+ for (int i = 0; i < sched->n_splits; i++) {
808
+ struct ggml_backend_sched_split * split = &splits[i];
809
+ ggml_backend_t split_backend = ggml_tallocr_get_buffer(split->tallocr)->backend;
810
+ int split_backend_id = sched_backend_prio(sched, split_backend);
811
+
812
+ // copy the input tensors to the split backend
813
+ uint64_t copy_start_us = ggml_time_us();
814
+ for (int j = 0; j < split->n_inputs; j++) {
815
+ struct ggml_tensor * input_cpy = sched->node_copies[hash_id(split->inputs[j])][sched_backend_prio(sched, split_backend)];
816
+ if (split->inputs[j]->buffer == NULL) {
817
+ if (split->inputs[j]->view_src == NULL) {
818
+ fprintf(stderr, "input %s has no buffer and no view_src\n", split->inputs[j]->name);
819
+ exit(1);
820
+ }
821
+ struct ggml_tensor * view = split->inputs[j];
822
+ view->backend = view->view_src->backend;
823
+ view->buffer = view->view_src->buffer;
824
+ view->data = (char *)view->view_src->data + view->view_offs;
825
+ ggml_backend_buffer_init_tensor(ggml_backend_sched_get_buffer(sched, view->buffer->backend), view);
826
+ }
827
+ if (input_cpy->buffer == NULL) {
828
+ fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name);
829
+ exit(1);
830
+ }
831
+ GGML_ASSERT(split->inputs[j]->buffer->backend != input_cpy->buffer->backend);
832
+ GGML_ASSERT(input_cpy->buffer->backend == split_backend);
833
+ ggml_backend_tensor_copy(split->inputs[j], input_cpy);
834
+ }
835
+ // ggml_backend_synchronize(split_backend);
836
+ int64_t copy_end_us = ggml_time_us();
837
+ copy_us[split_backend_id] += copy_end_us - copy_start_us;
838
+
839
+ #if 0
840
+ char split_filename[GGML_MAX_NAME];
841
+ snprintf(split_filename, GGML_MAX_NAME, "split_%i_%s.dot", i, ggml_backend_name(split_backend));
842
+ ggml_graph_dump_dot(split->graph, NULL, split_filename);
843
+ #endif
844
+
845
+ uint64_t compute_start_us = ggml_time_us();
846
+ ggml_backend_graph_compute(split_backend, split->graph);
847
+ // ggml_backend_synchronize(split_backend);
848
+ uint64_t compute_end_us = ggml_time_us();
849
+ compute_us[split_backend_id] += compute_end_us - compute_start_us;
850
+ }
851
+
852
+ #if 0
853
+ // per-backend timings
854
+ fprintf(stderr, "sched_compute_splits times (%d splits):\n", sched->n_splits);
855
+ for (int i = 0; i < sched->n_backends; i++) {
856
+ if (copy_us[i] > 0 || compute_us[i] > 0) {
857
+ fprintf(stderr, "\t%5.5s: %lu us copy, %lu us compute\n", ggml_backend_name(sched->backends[i]), copy_us[i], compute_us[i]);
858
+ }
859
+ }
860
+ #endif
861
+ }
862
+
863
+ static void sched_reset(ggml_backend_sched_t sched) {
864
+ for (int i = 0; i < sched->n_backends; i++) {
865
+ ggml_tallocr_reset(sched->tallocs[i]);
866
+ }
867
+ }
868
+
869
+ ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends) {
870
+ GGML_ASSERT(n_backends <= GGML_MAX_BACKENDS);
871
+
872
+ struct ggml_backend_sched * sched = malloc(sizeof(struct ggml_backend_sched));
873
+ memset(sched, 0, sizeof(struct ggml_backend_sched));
874
+
875
+ fprintf(stderr, "ggml_backend_sched size: %lu KB\n", sizeof(struct ggml_backend_sched)/1024);
876
+
877
+ sched->n_backends = n_backends;
878
+ for (int i = 0; i < n_backends; i++) {
879
+ sched->backends[i] = backends[i];
880
+ }
881
+
882
+ sched->galloc = ggml_gallocr_new();
883
+
884
+ // init measure allocs for each backend
885
+ for (int i = 0; i < n_backends; i++) {
886
+ sched->tallocs[i] = ggml_tallocr_new_measure_from_backend(backends[i]);
887
+ }
888
+
889
+ return sched;
890
+ }
891
+
892
+ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
893
+ if (sched == NULL) {
894
+ return;
895
+ }
896
+ for (int i = 0; i < sched->n_backends; i++) {
897
+ ggml_tallocr_free(sched->tallocs[i]);
898
+ }
899
+ ggml_gallocr_free(sched->galloc);
900
+ free(sched->hash_set.keys);
901
+ free(sched->node_talloc);
902
+ free(sched->node_copies);
903
+ free(sched);
904
+ }
905
+
906
+ void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
907
+ // initialize hash tables
908
+ size_t hash_size = measure_graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS;
909
+ sched->hash_set.size = hash_size;
910
+ sched->hash_set.keys = malloc(sizeof(sched->hash_set.keys[0]) * hash_size);
911
+ sched->node_talloc = malloc(sizeof(sched->node_talloc[0]) * hash_size);
912
+ sched->node_copies = malloc(sizeof(sched->node_copies[0]) * hash_size);
913
+
914
+ sched_split_graph(sched, measure_graph);
915
+ sched_alloc_splits(sched);
916
+
917
+ // allocate buffers and reset allocators
918
+ for (int i = 0; i < sched->n_backends; i++) {
919
+ size_t size = ggml_tallocr_max_size(sched->tallocs[i]);
920
+ ggml_tallocr_free(sched->tallocs[i]);
921
+ sched->tallocs[i] = ggml_tallocr_new_from_backend(sched->backends[i], size);
922
+ }
923
+
924
+ sched_reset(sched);
925
+ }
926
+
927
+ void ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
928
+ GGML_ASSERT(sched->hash_set.size >= graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS);
929
+
930
+ sched_split_graph(sched, graph);
931
+ sched_alloc_splits(sched);
932
+ sched_compute_splits(sched);
933
+ sched_reset(sched);
934
+ }
935
+
936
+ ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend) {
937
+ int backend_index = sched_backend_prio(sched, backend);
938
+ return sched->tallocs[backend_index];
939
+ }
940
+
941
+ ggml_backend_buffer_t ggml_backend_sched_get_buffer(ggml_backend_sched_t sched, ggml_backend_t backend) {
942
+ int backend_index = sched_backend_prio(sched, backend);
943
+ return ggml_tallocr_get_buffer(sched->tallocs[backend_index]);
944
+ }
945
+
946
+ void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
947
+ int backend_index = sched_backend_prio(sched, backend);
948
+ GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
949
+ node_allocr(node) = sched->tallocs[backend_index];
950
+ }
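For context, the scheduler functions above are meant to be driven in two phases: one measure pass over a representative graph to size the per-backend buffers, then repeated compute passes over graphs of the same shape. A minimal sketch of that call sequence follows; the `build_graph` callback and its `user_ctx` argument are hypothetical placeholders, not part of this commit.

    // sketch only: assumes this commit's ggml.h and ggml-backend.h are on the include path
    #include "ggml.h"
    #include "ggml-backend.h"

    static void run_with_sched(ggml_backend_t * backends, int n_backends,
                               struct ggml_cgraph * (*build_graph)(void * ctx), void * user_ctx) {
        // the scheduler starts out with measure allocators only
        ggml_backend_sched_t sched = ggml_backend_sched_new(backends, n_backends);

        // measure phase: split the graph and record the maximum size needed per backend
        ggml_backend_sched_init_measure(sched, build_graph(user_ctx));

        // compute phase: later graphs must not exceed the measured hash table / buffer sizes
        ggml_backend_sched_graph_compute(sched, build_graph(user_ctx));

        ggml_backend_sched_free(sched);
    }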
ggml-backend.h ADDED
@@ -0,0 +1,136 @@
1
+ #pragma once
2
+
3
+ #include "ggml.h"
4
+ #include "ggml-alloc.h"
5
+
6
+ #ifdef __cplusplus
7
+ extern "C" {
8
+ #endif
9
+
10
+ //
11
+ // Backend buffer
12
+ //
13
+
14
+ struct ggml_backend_buffer;
15
+ typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
16
+
17
+ // backend buffer functions
18
+ GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
19
+ GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
20
+ GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
21
+ GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
22
+ GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
23
+ GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
24
+ GGML_API void ggml_backend_buffer_free_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
25
+
26
+ //
27
+ // Backend
28
+ //
29
+
30
+ struct ggml_backend;
31
+ typedef struct ggml_backend * ggml_backend_t;
32
+ typedef void * ggml_backend_graph_plan_t;
33
+
34
+ GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor);
35
+
36
+ GGML_API const char * ggml_backend_name(ggml_backend_t backend);
37
+ GGML_API void ggml_backend_free(ggml_backend_t backend);
38
+
39
+ GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
40
+
41
+ GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
42
+
43
+ GGML_API void ggml_backend_tensor_set_async( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
44
+ GGML_API void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
45
+
46
+ GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
47
+ GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
48
+
49
+ GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
50
+
51
+ GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create (ggml_backend_t backend, struct ggml_cgraph * cgraph);
52
+
53
+ GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
54
+ GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
55
+ GGML_API void ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
56
+ GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op);
57
+
58
+ // tensor copy between different backends
59
+ GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
60
+
61
+ //
62
+ // CPU backend
63
+ //
64
+
65
+ GGML_API ggml_backend_t ggml_backend_cpu_init(void);
66
+
67
+ GGML_API bool ggml_backend_is_cpu(ggml_backend_t backend);
68
+ GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
69
+
70
+ // Create a backend buffer from an existing pointer
71
+ GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size);
72
+
73
+
74
+ //
75
+ // Backend scheduler
76
+ //
77
+
78
+ // The backend scheduler allows for multiple backends to be used together
79
+ // Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends
80
+ // The backends are selected based on:
81
+ // - the backend that supports the operation
82
+ // - the location of the pre-allocated tensors (e.g. the weights)
83
+ /*
84
+ Example usage:
85
+
86
+ sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, num_backends);
87
+ // sched is initialized with measure allocators and cannot be used until allocated with a measure graph
88
+
89
+ // initialize buffers from a measure graph
90
+ measure_graph = build_graph(sched); // use the allocr to allocate inputs as needed
91
+
92
+ // in build_graph:
93
+ build_graph(...) {
94
+ // allocating tensors in a specific backend (optional, recommended: pre-allocate inputs in a different buffer)
95
+ alloc_cpu = ggml_backend_sched_get_allocr(sched, backend_cpu);
96
+ ggml_allocr_alloc(alloc_cpu, tensor);
97
+
98
+ // manually assigning nodes to a backend (optional, shouldn't be needed in most cases)
99
+ struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
100
+ ggml_backend_sched_set_node_backend(sched, node, backend_gpu);
101
+ }
102
+
103
+ // allocate backend buffers from measure graph
104
+ ggml_backend_sched_init_measure(sched, measure_graph);
105
+
106
+ // the scheduler is now ready to compute graphs
107
+
108
+ // compute
109
+ graph = build_graph(sched);
110
+ ggml_backend_sched_graph_compute(sched, graph);
111
+ */
112
+
113
+ struct ggml_backend_sched;
114
+ typedef struct ggml_backend_sched * ggml_backend_sched_t;
115
+
116
+ // Initialize a backend scheduler
117
+ GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends);
118
+
119
+ GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
120
+
121
+ // Initialize backend buffers from a measure graph
122
+ GGML_API void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
123
+
124
+ GGML_API ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend);
125
+ GGML_API ggml_backend_buffer_t ggml_backend_sched_get_buffer (ggml_backend_sched_t sched, ggml_backend_t backend);
126
+
127
+ GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
128
+
129
+ // Allocate and compute a graph on the backend scheduler
130
+ GGML_API void ggml_backend_sched_graph_compute(
131
+ ggml_backend_sched_t sched,
132
+ struct ggml_cgraph * graph);
133
+
134
+ #ifdef __cplusplus
135
+ }
136
+ #endif
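The CPU backend declared above is enough to exercise the generic buffer API on its own. The sketch below is an illustration under stated assumptions, not code from this commit: it initializes the CPU backend, queries its alignment, then allocates and frees a backend buffer.

    #include <stdio.h>
    #include "ggml.h"
    #include "ggml-backend.h"

    int main(void) {
        ggml_backend_t cpu = ggml_backend_cpu_init();
        ggml_backend_cpu_set_n_threads(cpu, 4);

        printf("backend: %s, alignment: %zu\n",
               ggml_backend_name(cpu), ggml_backend_get_alignment(cpu));

        // backend-agnostic compute buffer; for the CPU backend this is plain host memory
        ggml_backend_buffer_t buf = ggml_backend_alloc_buffer(cpu, 16*1024*1024);
        printf("buffer size: %zu bytes\n", ggml_backend_buffer_get_size(buf));

        ggml_backend_buffer_free(buf);
        ggml_backend_free(cpu);
        return 0;
    }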
ggml-cuda.cu CHANGED
The diff for this file is too large to render. See raw diff
 
ggml-cuda.h CHANGED
@@ -1,6 +1,7 @@
1
  #pragma once
2
 
3
  #include "ggml.h"
 
4
 
5
  #ifdef GGML_USE_HIPBLAS
6
  #define GGML_CUDA_NAME "ROCm"
@@ -31,6 +32,7 @@ GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tens
31
 
32
  GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
33
  GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
 
34
 
35
  GGML_API void ggml_cuda_set_main_device(int main_device);
36
  GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
@@ -41,6 +43,9 @@ GGML_API bool ggml_cuda_compute_forward(struct ggml_compute_params * params, s
41
  GGML_API int ggml_cuda_get_device_count(void);
42
  GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
43
 
 
 
 
44
  #ifdef __cplusplus
45
  }
46
  #endif
 
1
  #pragma once
2
 
3
  #include "ggml.h"
4
+ #include "ggml-backend.h"
5
 
6
  #ifdef GGML_USE_HIPBLAS
7
  #define GGML_CUDA_NAME "ROCm"
 
32
 
33
  GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
34
  GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
35
+ GGML_API void ggml_cuda_copy_to_device(struct ggml_tensor * tensor);
36
 
37
  GGML_API void ggml_cuda_set_main_device(int main_device);
38
  GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
 
43
  GGML_API int ggml_cuda_get_device_count(void);
44
  GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
45
 
46
+ // backend API
47
+ GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use
48
+
49
  #ifdef __cplusplus
50
  }
51
  #endif
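With the new ggml_backend_cuda_init entry point, a host application can pick the CUDA backend when it is compiled in and fall back to the CPU backend otherwise. A hedged sketch, assuming the usual GGML_USE_CUBLAS build flag and that initialization returns NULL on failure:

    #include "ggml-backend.h"
    #ifdef GGML_USE_CUBLAS
    #include "ggml-cuda.h"
    #endif

    static ggml_backend_t init_best_backend(void) {
        ggml_backend_t backend = NULL;
    #ifdef GGML_USE_CUBLAS
        backend = ggml_backend_cuda_init(); // assumed to return NULL if no usable device
    #endif
        if (backend == NULL) {
            backend = ggml_backend_cpu_init();
        }
        return backend;
    }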
ggml-impl.h ADDED
@@ -0,0 +1,243 @@
1
+ #pragma once
2
+
3
+ #include "ggml.h"
4
+
5
+ // GGML internal header
6
+
7
+ #include <assert.h>
8
+ #include <stddef.h>
9
+ #include <stdbool.h>
10
+ #include <string.h> // memcpy
11
+ #include <math.h> // fabsf
12
+
13
+ #ifdef __cplusplus
14
+ extern "C" {
15
+ #endif
16
+
17
+ // static_assert should be a #define, but if it's not,
18
+ // fall back to the _Static_assert C11 keyword.
19
+ // if C99 - static_assert is noop
20
+ // ref: https://stackoverflow.com/a/53923785/4039976
21
+ #ifndef static_assert
22
+ #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
23
+ #define static_assert(cond, msg) _Static_assert(cond, msg)
24
+ #else
25
+ #define static_assert(cond, msg) struct global_scope_noop_trick
26
+ #endif
27
+ #endif
28
+
29
+ // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
30
+ #if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
31
+ #ifndef __FMA__
32
+ #define __FMA__
33
+ #endif
34
+ #ifndef __F16C__
35
+ #define __F16C__
36
+ #endif
37
+ #ifndef __SSE3__
38
+ #define __SSE3__
39
+ #endif
40
+ #endif
41
+
42
+ // 16-bit float
43
+ // on Arm, we use __fp16
44
+ // on x86, we use uint16_t
45
+ #if defined(__ARM_NEON) && !defined(_MSC_VER)
46
+
47
+ // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
48
+ //
49
+ // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
50
+ //
51
+ #include <arm_neon.h>
52
+
53
+ #define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
54
+ #define GGML_COMPUTE_FP32_TO_FP16(x) (x)
55
+
56
+ #define GGML_FP16_TO_FP32(x) ((float) (x))
57
+ #define GGML_FP32_TO_FP16(x) (x)
58
+
59
+ #else
60
+
61
+ #ifdef __wasm_simd128__
62
+ #include <wasm_simd128.h>
63
+ #else
64
+ #ifdef __POWER9_VECTOR__
65
+ #include <altivec.h>
66
+ #undef bool
67
+ #define bool _Bool
68
+ #else
69
+ #if defined(_MSC_VER) || defined(__MINGW32__)
70
+ #include <intrin.h>
71
+ #else
72
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
73
+ #if !defined(__riscv)
74
+ #include <immintrin.h>
75
+ #endif
76
+ #endif
77
+ #endif
78
+ #endif
79
+ #endif
80
+
81
+ #ifdef __riscv_v_intrinsic
82
+ #include <riscv_vector.h>
83
+ #endif
84
+
85
+ #ifdef __F16C__
86
+
87
+ #ifdef _MSC_VER
88
+ #define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
89
+ #define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
90
+ #else
91
+ #define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
92
+ #define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
93
+ #endif
94
+
95
+ #elif defined(__POWER9_VECTOR__)
96
+
97
+ #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
98
+ #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
99
+ /* the inline asm below is about 12% faster than the lookup method */
100
+ #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
101
+ #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
102
+
103
+ static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
104
+ register float f;
105
+ register double d;
106
+ __asm__(
107
+ "mtfprd %0,%2\n"
108
+ "xscvhpdp %0,%0\n"
109
+ "frsp %1,%0\n" :
110
+ /* temp */ "=d"(d),
111
+ /* out */ "=f"(f):
112
+ /* in */ "r"(h));
113
+ return f;
114
+ }
115
+
116
+ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
117
+ register double d;
118
+ register ggml_fp16_t r;
119
+ __asm__( /* xscvdphp can work on double or single precision */
120
+ "xscvdphp %0,%2\n"
121
+ "mffprd %1,%0\n" :
122
+ /* temp */ "=d"(d),
123
+ /* out */ "=r"(r):
124
+ /* in */ "f"(f));
125
+ return r;
126
+ }
127
+
128
+ #else
129
+
130
+ // FP16 <-> FP32
131
+ // ref: https://github.com/Maratyszcza/FP16
132
+
133
+ static inline float fp32_from_bits(uint32_t w) {
134
+ union {
135
+ uint32_t as_bits;
136
+ float as_value;
137
+ } fp32;
138
+ fp32.as_bits = w;
139
+ return fp32.as_value;
140
+ }
141
+
142
+ static inline uint32_t fp32_to_bits(float f) {
143
+ union {
144
+ float as_value;
145
+ uint32_t as_bits;
146
+ } fp32;
147
+ fp32.as_value = f;
148
+ return fp32.as_bits;
149
+ }
150
+
151
+ static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
152
+ const uint32_t w = (uint32_t) h << 16;
153
+ const uint32_t sign = w & UINT32_C(0x80000000);
154
+ const uint32_t two_w = w + w;
155
+
156
+ const uint32_t exp_offset = UINT32_C(0xE0) << 23;
157
+ #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
158
+ const float exp_scale = 0x1.0p-112f;
159
+ #else
160
+ const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
161
+ #endif
162
+ const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
163
+
164
+ const uint32_t magic_mask = UINT32_C(126) << 23;
165
+ const float magic_bias = 0.5f;
166
+ const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
167
+
168
+ const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
169
+ const uint32_t result = sign |
170
+ (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
171
+ return fp32_from_bits(result);
172
+ }
173
+
174
+ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
175
+ #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
176
+ const float scale_to_inf = 0x1.0p+112f;
177
+ const float scale_to_zero = 0x1.0p-110f;
178
+ #else
179
+ const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
180
+ const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
181
+ #endif
182
+ float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
183
+
184
+ const uint32_t w = fp32_to_bits(f);
185
+ const uint32_t shl1_w = w + w;
186
+ const uint32_t sign = w & UINT32_C(0x80000000);
187
+ uint32_t bias = shl1_w & UINT32_C(0xFF000000);
188
+ if (bias < UINT32_C(0x71000000)) {
189
+ bias = UINT32_C(0x71000000);
190
+ }
191
+
192
+ base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
193
+ const uint32_t bits = fp32_to_bits(base);
194
+ const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
195
+ const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
196
+ const uint32_t nonsign = exp_bits + mantissa_bits;
197
+ return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
198
+ }
199
+
200
+ #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
201
+ #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
202
+
203
+ #endif // __F16C__
204
+
205
+ #endif // __ARM_NEON
206
+
207
+ // precomputed f32 table for f16 (256 KB)
208
+ // defined in ggml.c, initialized in ggml_init()
209
+ extern float ggml_table_f32_f16[1 << 16];
210
+
211
+ // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
212
+ // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
213
+ // This is also true for POWER9.
214
+ #if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
215
+
216
+ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
217
+ uint16_t s;
218
+ memcpy(&s, &f, sizeof(uint16_t));
219
+ return ggml_table_f32_f16[s];
220
+ }
221
+
222
+ #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
223
+ #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
224
+
225
+ #endif
226
+
227
+ #define GGML_HASHTABLE_FULL ((size_t)-1)
228
+ #define GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2)
229
+
230
+ bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
231
+
232
+ // returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
233
+ size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
234
+
235
+ // returns GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
236
+ size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key);
237
+
238
+ // return index, asserts if table is full
239
+ size_t ggml_hash_find_or_insert( struct ggml_hash_set hash_set, struct ggml_tensor * key);
240
+
241
+ #ifdef __cplusplus
242
+ }
243
+ #endif
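The conversion helpers above all encode the same IEEE-754 binary16 layout: 1 sign bit, 5 exponent bits with bias 15, and 10 mantissa bits. The standalone snippet below decodes a few half-precision bit patterns by hand (normal values only, no ggml dependency) to make that layout concrete; it is an illustration, not part of the header.

    #include <stdio.h>
    #include <stdint.h>
    #include <math.h>

    // decode a binary16 bit pattern: (-1)^sign * 2^(e-15) * (1 + mant/1024)
    static float fp16_bits_to_float(uint16_t h) {
        const int      sign = (h >> 15) & 0x1;
        const int      e    = (h >> 10) & 0x1f;
        const uint32_t mant =  h        & 0x3ff;
        const float value = ldexpf(1.0f + mant/1024.0f, e - 15); // normal numbers only
        return sign ? -value : value;
    }

    int main(void) {
        printf("0x3C00 -> %f (expect  1.0)\n", fp16_bits_to_float(0x3C00));
        printf("0xC000 -> %f (expect -2.0)\n", fp16_bits_to_float(0xC000));
        printf("0x3555 -> %f (~ 1/3)\n",       fp16_bits_to_float(0x3555));
        return 0;
    }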
ggml-metal.h CHANGED
@@ -19,6 +19,9 @@
19
 
20
  #pragma once
21
 
 
 
 
22
  #include <stddef.h>
23
  #include <stdbool.h>
24
 
@@ -33,8 +36,15 @@ struct ggml_cgraph;
33
  extern "C" {
34
  #endif
35
 
 
 
 
 
 
36
  struct ggml_metal_context;
37
 
 
 
38
  // number of command buffers to use
39
  struct ggml_metal_context * ggml_metal_init(int n_cb);
40
  void ggml_metal_free(struct ggml_metal_context * ctx);
@@ -79,6 +89,17 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
79
  // creates gf->n_threads command buffers in parallel
80
  void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
81
 
 
 
 
 
 
 
 
 
 
 
 
82
  #ifdef __cplusplus
83
  }
84
  #endif
 
19
 
20
  #pragma once
21
 
22
+ #include "ggml.h"
23
+ #include "ggml-backend.h"
24
+
25
  #include <stddef.h>
26
  #include <stdbool.h>
27
 
 
36
  extern "C" {
37
  #endif
38
 
39
+ //
40
+ // internal API
41
+ // temporarily exposed to user-code
42
+ //
43
+
44
  struct ggml_metal_context;
45
 
46
+ void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data);
47
+
48
  // number of command buffers to use
49
  struct ggml_metal_context * ggml_metal_init(int n_cb);
50
  void ggml_metal_free(struct ggml_metal_context * ctx);
 
89
  // creates gf->n_threads command buffers in parallel
90
  void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
91
 
92
+ //
93
+ // backend API
94
+ // user-code should use only these functions
95
+ //
96
+
97
+ GGML_API ggml_backend_t ggml_backend_metal_init(void);
98
+
99
+ GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
100
+
101
+ GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
102
+
103
  #ifdef __cplusplus
104
  }
105
  #endif
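Together with the logging hook, the backend entry points above are what a host application is expected to call. A sketch under stated assumptions: the GGML_USE_METAL build flag, ggml_log_callback in ggml.h taking (level, text, user_data), and ggml_backend_metal_init returning NULL when Metal is unavailable; my_log_cb is a placeholder name.

    #include <stdio.h>
    #include "ggml.h"
    #include "ggml-backend.h"
    #ifdef GGML_USE_METAL
    #include "ggml-metal.h"
    #endif

    static void my_log_cb(enum ggml_log_level level, const char * text, void * user_data) {
        (void) level; (void) user_data;
        fputs(text, stderr);
    }

    static ggml_backend_t init_metal_or_cpu(void) {
    #ifdef GGML_USE_METAL
        ggml_metal_log_set_callback(my_log_cb, NULL);   // route Metal log output through the app
        ggml_backend_t backend = ggml_backend_metal_init();
        if (backend != NULL) {
            ggml_backend_metal_set_n_cb(backend, 1);    // number of Metal command buffers
            return backend;
        }
    #endif
        return ggml_backend_cpu_init();                 // fall back to the CPU backend
    }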
ggml-metal.m CHANGED
@@ -1,5 +1,6 @@
1
  #import "ggml-metal.h"
2
 
 
3
  #import "ggml.h"
4
 
5
  #import <Foundation/Foundation.h>
@@ -11,16 +12,19 @@
11
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
12
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
13
 
14
- // TODO: temporary - reuse llama.cpp logging
15
  #ifdef GGML_METAL_NDEBUG
16
- #define metal_printf(...)
 
 
17
  #else
18
- #define metal_printf(...) fprintf(stderr, __VA_ARGS__)
 
 
19
  #endif
20
 
21
  #define UNUSED(x) (void)(x)
22
 
23
- #define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
24
 
25
  struct ggml_metal_buffer {
26
  const char * name;
@@ -59,6 +63,7 @@ struct ggml_metal_context {
59
  GGML_METAL_DECL_KERNEL(mul);
60
  GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
61
  GGML_METAL_DECL_KERNEL(scale);
 
62
  GGML_METAL_DECL_KERNEL(silu);
63
  GGML_METAL_DECL_KERNEL(relu);
64
  GGML_METAL_DECL_KERNEL(gelu);
@@ -70,6 +75,8 @@ struct ggml_metal_context {
70
  GGML_METAL_DECL_KERNEL(get_rows_f16);
71
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
72
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
 
 
73
  GGML_METAL_DECL_KERNEL(get_rows_q8_0);
74
  GGML_METAL_DECL_KERNEL(get_rows_q2_K);
75
  GGML_METAL_DECL_KERNEL(get_rows_q3_K);
@@ -78,33 +85,40 @@ struct ggml_metal_context {
78
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
79
  GGML_METAL_DECL_KERNEL(rms_norm);
80
  GGML_METAL_DECL_KERNEL(norm);
81
- GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
82
- GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
83
- GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
84
- GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
85
- GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
86
- GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
87
- GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
88
- GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
89
- GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
90
- GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
91
- GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
92
- GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
 
 
93
  GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
94
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
95
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
96
  GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
 
 
97
  GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
98
  GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
99
  GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
100
  GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
101
  GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
102
  GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
103
- GGML_METAL_DECL_KERNEL(rope);
 
104
  GGML_METAL_DECL_KERNEL(alibi_f32);
105
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
106
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
107
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
 
 
108
 
109
  #undef GGML_METAL_DECL_KERNEL
110
  };
@@ -120,8 +134,37 @@ static NSString * const msl_library_source = @"see metal.metal";
120
  @implementation GGMLMetalClass
121
  @end
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
124
- metal_printf("%s: allocating\n", __func__);
125
 
126
  id <MTLDevice> device;
127
  NSString * s;
@@ -131,14 +174,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
131
  NSArray * devices = MTLCopyAllDevices();
132
  for (device in devices) {
133
  s = [device name];
134
- metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
135
  }
136
  #endif
137
 
138
  // Pick and show default Metal device
139
  device = MTLCreateSystemDefaultDevice();
140
  s = [device name];
141
- metal_printf("%s: picking default device: %s\n", __func__, [s UTF8String]);
142
 
143
  // Configure context
144
  struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
@@ -150,68 +193,69 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
150
 
151
  ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
152
 
153
- #ifdef GGML_SWIFT
154
- // load the default.metallib file
155
  {
156
- NSError * error = nil;
157
-
158
- NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
159
- NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
160
- NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
161
- NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
162
- NSURL * libURL = [NSURL fileURLWithPath:libPath];
163
-
164
- // Load the metallib file into a Metal library
165
- ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
166
-
167
- if (error) {
168
- metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
169
- return NULL;
170
- }
171
- }
172
  #else
173
- UNUSED(msl_library_source);
174
-
175
- // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource
176
- {
177
  NSError * error = nil;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
- //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
180
- NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
181
- NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
182
- metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
183
-
184
- NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
185
- if (error) {
186
- metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
187
- return NULL;
188
- }
189
-
190
  #ifdef GGML_QKK_64
191
- MTLCompileOptions* options = [MTLCompileOptions new];
192
- options.preprocessorMacros = @{ @"QK_K" : @(64) };
193
- ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
194
- #else
195
- ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
196
  #endif
 
 
 
197
  if (error) {
198
- metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
199
  return NULL;
200
  }
201
  }
202
- #endif
203
 
204
  // load kernels
205
  {
206
  NSError * error = nil;
 
 
 
 
 
 
207
  #define GGML_METAL_ADD_KERNEL(name) \
208
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
209
  ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
210
- metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
211
- (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
212
- (int) ctx->pipeline_##name.threadExecutionWidth); \
213
  if (error) { \
214
- metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
215
  return NULL; \
216
  }
217
 
@@ -220,6 +264,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
220
  GGML_METAL_ADD_KERNEL(mul);
221
  GGML_METAL_ADD_KERNEL(mul_row);
222
  GGML_METAL_ADD_KERNEL(scale);
 
223
  GGML_METAL_ADD_KERNEL(silu);
224
  GGML_METAL_ADD_KERNEL(relu);
225
  GGML_METAL_ADD_KERNEL(gelu);
@@ -231,6 +276,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
231
  GGML_METAL_ADD_KERNEL(get_rows_f16);
232
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
233
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
 
 
234
  GGML_METAL_ADD_KERNEL(get_rows_q8_0);
235
  GGML_METAL_ADD_KERNEL(get_rows_q2_K);
236
  GGML_METAL_ADD_KERNEL(get_rows_q3_K);
@@ -239,44 +286,66 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
239
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
240
  GGML_METAL_ADD_KERNEL(rms_norm);
241
  GGML_METAL_ADD_KERNEL(norm);
242
- GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
243
- GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
244
- GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
245
- GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
246
- GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
247
- GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
248
- GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
249
- GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
250
- GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
251
- GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
252
- GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
253
- GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
254
- GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
255
- GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
256
- GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
257
- GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
258
- GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
259
- GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
260
- GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
261
- GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
262
- GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
263
- GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
264
- GGML_METAL_ADD_KERNEL(rope);
 
 
 
 
 
 
 
265
  GGML_METAL_ADD_KERNEL(alibi_f32);
266
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
267
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
268
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
 
 
269
 
270
  #undef GGML_METAL_ADD_KERNEL
271
  }
272
 
273
- metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
274
  #if TARGET_OS_OSX
275
- metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  if (ctx->device.maxTransferRate != 0) {
277
- metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
278
  } else {
279
- metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
280
  }
281
  #endif
282
 
@@ -284,7 +353,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
284
  }
285
 
286
  void ggml_metal_free(struct ggml_metal_context * ctx) {
287
- metal_printf("%s: deallocating\n", __func__);
288
  #define GGML_METAL_DEL_KERNEL(name) \
289
  [ctx->function_##name release]; \
290
  [ctx->pipeline_##name release];
@@ -294,6 +363,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
294
  GGML_METAL_DEL_KERNEL(mul);
295
  GGML_METAL_DEL_KERNEL(mul_row);
296
  GGML_METAL_DEL_KERNEL(scale);
 
297
  GGML_METAL_DEL_KERNEL(silu);
298
  GGML_METAL_DEL_KERNEL(relu);
299
  GGML_METAL_DEL_KERNEL(gelu);
@@ -305,6 +375,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
305
  GGML_METAL_DEL_KERNEL(get_rows_f16);
306
  GGML_METAL_DEL_KERNEL(get_rows_q4_0);
307
  GGML_METAL_DEL_KERNEL(get_rows_q4_1);
 
 
308
  GGML_METAL_DEL_KERNEL(get_rows_q8_0);
309
  GGML_METAL_DEL_KERNEL(get_rows_q2_K);
310
  GGML_METAL_DEL_KERNEL(get_rows_q3_K);
@@ -313,33 +385,42 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
313
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
314
  GGML_METAL_DEL_KERNEL(rms_norm);
315
  GGML_METAL_DEL_KERNEL(norm);
316
- GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
317
- GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
318
- GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
319
- GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
320
- GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
321
- GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
322
- GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
323
- GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
324
- GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
325
- GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
326
- GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
327
- GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
328
- GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
329
- GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
330
- GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
331
- GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
332
- GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
333
- GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
334
- GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
335
- GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
336
- GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
337
- GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
338
- GGML_METAL_DEL_KERNEL(rope);
 
 
 
 
 
 
 
339
  GGML_METAL_DEL_KERNEL(alibi_f32);
340
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
341
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
342
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
 
 
343
 
344
  #undef GGML_METAL_DEL_KERNEL
345
 
@@ -360,7 +441,7 @@ void * ggml_metal_host_malloc(size_t n) {
360
  void * data = NULL;
361
  const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
362
  if (result != 0) {
363
- metal_printf("%s: error: posix_memalign failed\n", __func__);
364
  return NULL;
365
  }
366
 
@@ -388,7 +469,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
388
  // Metal buffer based on the host memory pointer
389
  //
390
  static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
391
- //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
392
 
393
  const int64_t tsize = ggml_nbytes(t);
394
 
@@ -396,17 +477,17 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
396
  for (int i = 0; i < ctx->n_buffers; ++i) {
397
  const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
398
 
399
- //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
400
  if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
401
  *offs = (size_t) ioffs;
402
 
403
- //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
404
 
405
  return ctx->buffers[i].metal;
406
  }
407
  }
408
 
409
- metal_printf("%s: error: buffer is nil\n", __func__);
410
 
411
  return nil;
412
  }
@@ -418,7 +499,7 @@ bool ggml_metal_add_buffer(
418
  size_t size,
419
  size_t max_size) {
420
  if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
421
- metal_printf("%s: too many buffers\n", __func__);
422
  return false;
423
  }
424
 
@@ -428,7 +509,7 @@ bool ggml_metal_add_buffer(
428
  const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
429
 
430
  if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
431
- metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
432
  return false;
433
  }
434
  }
@@ -449,11 +530,11 @@ bool ggml_metal_add_buffer(
449
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
450
 
451
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
452
- metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
453
  return false;
454
  }
455
 
456
- metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
457
 
458
  ++ctx->n_buffers;
459
  } else {
@@ -473,13 +554,13 @@ bool ggml_metal_add_buffer(
473
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
474
 
475
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
476
- metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
477
  return false;
478
  }
479
 
480
- metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
481
  if (i + size_step < size) {
482
- metal_printf("\n");
483
  }
484
 
485
  ++ctx->n_buffers;
@@ -487,17 +568,17 @@ bool ggml_metal_add_buffer(
487
  }
488
 
489
  #if TARGET_OS_OSX
490
- metal_printf(", (%8.2f / %8.2f)",
491
  ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
492
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
493
 
494
  if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
495
- metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
496
  } else {
497
- metal_printf("\n");
498
  }
499
  #else
500
- metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
501
  #endif
502
  }
503
 
@@ -610,7 +691,7 @@ void ggml_metal_graph_find_concurrency(
610
  }
611
 
612
  if (ctx->concur_list_len > GGML_MAX_CONCUR) {
613
- metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
614
  }
615
  }
616
 
@@ -664,12 +745,26 @@ void ggml_metal_graph_compute(
664
  continue;
665
  }
666
 
667
- //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
668
 
669
  struct ggml_tensor * src0 = gf->nodes[i]->src[0];
670
  struct ggml_tensor * src1 = gf->nodes[i]->src[1];
671
  struct ggml_tensor * dst = gf->nodes[i];
672
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
673
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
674
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
675
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -708,53 +803,117 @@ void ggml_metal_graph_compute(
708
  id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
709
  id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
710
 
711
- //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
712
  //if (src0) {
713
- // metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
714
  // ggml_is_contiguous(src0), src0->name);
715
  //}
716
  //if (src1) {
717
- // metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
718
  // ggml_is_contiguous(src1), src1->name);
719
  //}
720
  //if (dst) {
721
- // metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
722
  // dst->name);
723
  //}
724
 
725
  switch (dst->op) {
726
- case GGML_OP_NONE:
727
- case GGML_OP_RESHAPE:
728
- case GGML_OP_VIEW:
729
- case GGML_OP_TRANSPOSE:
730
- case GGML_OP_PERMUTE:
731
  {
732
- // noop
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
733
  } break;
734
  case GGML_OP_ADD:
735
  {
736
  GGML_ASSERT(ggml_is_contiguous(src0));
737
  GGML_ASSERT(ggml_is_contiguous(src1));
738
 
739
- // utilize float4
740
- GGML_ASSERT(ne00 % 4 == 0);
741
- const int64_t nb = ne00/4;
742
 
743
- if (ggml_nelements(src1) == ne10) {
 
 
744
  // src1 is a row
745
  GGML_ASSERT(ne11 == 1);
 
 
746
  [encoder setComputePipelineState:ctx->pipeline_add_row];
 
 
747
  } else {
748
  [encoder setComputePipelineState:ctx->pipeline_add];
749
  }
750
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
751
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
752
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
753
- [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
754
-
755
- const int64_t n = ggml_nelements(dst)/4;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
756
 
757
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
 
758
  } break;
759
  case GGML_OP_MUL:
760
  {
@@ -787,13 +946,19 @@ void ggml_metal_graph_compute(
787
 
788
  const float scale = *(const float *) src1->data;
789
 
790
- [encoder setComputePipelineState:ctx->pipeline_scale];
 
 
 
 
 
 
 
 
791
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
792
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
793
  [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
794
 
795
- const int64_t n = ggml_nelements(dst)/4;
796
-
797
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
798
  } break;
799
  case GGML_OP_UNARY:
@@ -804,9 +969,10 @@ void ggml_metal_graph_compute(
804
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
805
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
806
 
807
- const int64_t n = ggml_nelements(dst)/4;
 
808
 
809
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
810
  } break;
811
  case GGML_UNARY_OP_RELU:
812
  {
@@ -824,23 +990,39 @@ void ggml_metal_graph_compute(
824
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
825
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
826
 
827
- const int64_t n = ggml_nelements(dst)/4;
 
828
 
829
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
830
  } break;
831
  default:
832
  {
833
- metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
834
  GGML_ASSERT(false);
835
  }
836
  } break;
 
 
 
 
 
 
 
 
 
 
 
837
  case GGML_OP_SOFT_MAX:
838
  {
839
- const int nth = 32;
840
 
841
  if (ne00%4 == 0) {
842
  [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
843
  } else {
 
 
 
 
844
  [encoder setComputePipelineState:ctx->pipeline_soft_max];
845
  }
846
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -848,8 +1030,9 @@ void ggml_metal_graph_compute(
848
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
849
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
850
  [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
 
851
 
852
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
853
  } break;
854
  case GGML_OP_DIAG_MASK_INF:
855
  {
@@ -875,26 +1058,53 @@ void ggml_metal_graph_compute(
875
  } break;
876
  case GGML_OP_MUL_MAT:
877
  {
878
- // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
879
-
880
  GGML_ASSERT(ne00 == ne10);
881
- // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
882
- uint gqa = ne12/ne02;
883
  GGML_ASSERT(ne03 == ne13);
884
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
885
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
886
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
887
- if (!ggml_is_transposed(src0) &&
 
888
  !ggml_is_transposed(src1) &&
889
  src1t == GGML_TYPE_F32 &&
890
- [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
891
- ne00%32 == 0 &&
892
- ne11 > 1) {
893
  switch (src0->type) {
894
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
895
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
896
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
897
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
 
 
898
  case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
899
  case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
900
  case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
@@ -918,17 +1128,18 @@ void ggml_metal_graph_compute(
918
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
919
  [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
920
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
921
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
922
  } else {
923
  int nth0 = 32;
924
  int nth1 = 1;
925
  int nrows = 1;
 
926
 
927
  // use custom matrix x vector kernel
928
  switch (src0t) {
929
  case GGML_TYPE_F32:
930
  {
931
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
932
  nrows = 4;
933
  } break;
934
  case GGML_TYPE_F16:
@@ -936,12 +1147,12 @@ void ggml_metal_graph_compute(
936
  nth0 = 32;
937
  nth1 = 1;
938
  if (ne11 * ne12 < 4) {
939
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
940
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
941
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
942
  nrows = ne11;
943
  } else {
944
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
945
  nrows = 4;
946
  }
947
  } break;
@@ -952,7 +1163,7 @@ void ggml_metal_graph_compute(
952
 
953
  nth0 = 8;
954
  nth1 = 8;
955
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
956
  } break;
957
  case GGML_TYPE_Q4_1:
958
  {
@@ -961,7 +1172,25 @@ void ggml_metal_graph_compute(
961
 
962
  nth0 = 8;
963
  nth1 = 8;
964
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
965
  } break;
966
  case GGML_TYPE_Q8_0:
967
  {
@@ -970,7 +1199,7 @@ void ggml_metal_graph_compute(
970
 
971
  nth0 = 8;
972
  nth1 = 8;
973
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
974
  } break;
975
  case GGML_TYPE_Q2_K:
976
  {
@@ -979,7 +1208,7 @@ void ggml_metal_graph_compute(
979
 
980
  nth0 = 2;
981
  nth1 = 32;
982
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
983
  } break;
984
  case GGML_TYPE_Q3_K:
985
  {
@@ -988,7 +1217,7 @@ void ggml_metal_graph_compute(
988
 
989
  nth0 = 2;
990
  nth1 = 32;
991
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
992
  } break;
993
  case GGML_TYPE_Q4_K:
994
  {
@@ -997,7 +1226,7 @@ void ggml_metal_graph_compute(
997
 
998
  nth0 = 4; //1;
999
  nth1 = 8; //32;
1000
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
1001
  } break;
1002
  case GGML_TYPE_Q5_K:
1003
  {
@@ -1006,7 +1235,7 @@ void ggml_metal_graph_compute(
1006
 
1007
  nth0 = 2;
1008
  nth1 = 32;
1009
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
1010
  } break;
1011
  case GGML_TYPE_Q6_K:
1012
  {
@@ -1015,11 +1244,11 @@ void ggml_metal_graph_compute(
1015
 
1016
  nth0 = 2;
1017
  nth1 = 32;
1018
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
1019
  } break;
1020
  default:
1021
  {
1022
- metal_printf("Asserting on type %d\n",(int)src0t);
1023
  GGML_ASSERT(false && "not implemented");
1024
  }
1025
  };
@@ -1043,8 +1272,9 @@ void ggml_metal_graph_compute(
1043
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1044
  [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
1045
 
1046
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
1047
- src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
 
1048
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1049
  }
1050
  else if (src0t == GGML_TYPE_Q4_K) {
@@ -1075,6 +1305,8 @@ void ggml_metal_graph_compute(
1075
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
1076
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
1077
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
 
 
1078
  case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
1079
  case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
1080
  case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
@@ -1097,10 +1329,12 @@ void ggml_metal_graph_compute(
1097
  } break;
1098
  case GGML_OP_RMS_NORM:
1099
  {
 
 
1100
  float eps;
1101
  memcpy(&eps, dst->op_params, sizeof(float));
1102
 
1103
- const int nth = 512;
1104
 
1105
  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
1106
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1119,7 +1353,7 @@ void ggml_metal_graph_compute(
1119
  float eps;
1120
  memcpy(&eps, dst->op_params, sizeof(float));
1121
 
1122
- const int nth = 256;
1123
 
1124
  [encoder setComputePipelineState:ctx->pipeline_norm];
1125
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1137,17 +1371,16 @@ void ggml_metal_graph_compute(
1137
  {
1138
  GGML_ASSERT((src0t == GGML_TYPE_F32));
1139
 
1140
- const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
 
 
1141
  const int n_head = ((int32_t *) dst->op_params)[1];
1142
  float max_bias;
1143
  memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
1144
 
1145
- if (__builtin_popcount(n_head) != 1) {
1146
- GGML_ASSERT(false && "only power-of-two n_head implemented");
1147
- }
1148
-
1149
  const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
1150
  const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
 
1151
 
1152
  [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
1153
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1168,55 +1401,74 @@ void ggml_metal_graph_compute(
1168
  [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1169
  [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1170
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1171
- [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
1172
-
1173
- const int nth = 32;
1174
 
1175
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1176
  } break;
1177
  case GGML_OP_ROPE:
1178
  {
1179
- const int n_past = ((int32_t *) dst->op_params)[0];
1180
- const int n_dims = ((int32_t *) dst->op_params)[1];
1181
- const int mode = ((int32_t *) dst->op_params)[2];
1182
 
1183
- float freq_base;
1184
- float freq_scale;
1185
- memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
1186
- memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
1187
 
1188
- [encoder setComputePipelineState:ctx->pipeline_rope];
1189
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1190
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1191
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1192
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1193
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1194
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1195
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1196
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1197
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1198
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1199
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1200
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1201
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1202
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1203
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1204
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1205
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1206
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1207
- [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
1208
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
1209
- [encoder setBytes:&mode length:sizeof( int) atIndex:20];
1210
- [encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
1211
- [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1212
 
1213
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
1214
  } break;
1215
  case GGML_OP_DUP:
1216
  case GGML_OP_CPY:
1217
  case GGML_OP_CONT:
1218
  {
1219
- const int nth = 32;
1220
 
1221
  switch (src0t) {
1222
  case GGML_TYPE_F32:
@@ -1261,7 +1513,7 @@ void ggml_metal_graph_compute(
1261
  } break;
1262
  default:
1263
  {
1264
- metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1265
  GGML_ASSERT(false);
1266
  }
1267
  }
@@ -1286,10 +1538,147 @@ void ggml_metal_graph_compute(
1286
 
1287
  MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
1288
  if (status != MTLCommandBufferStatusCompleted) {
1289
- metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
1290
  GGML_ASSERT(false);
1291
  }
1292
  }
1293
 
1294
  }
1295
  }

1
  #import "ggml-metal.h"
2
 
3
+ #import "ggml-backend-impl.h"
4
  #import "ggml.h"
5
 
6
  #import <Foundation/Foundation.h>
 
12
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
13
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
14
 
 
15
  #ifdef GGML_METAL_NDEBUG
16
+ #define GGML_METAL_LOG_INFO(...)
17
+ #define GGML_METAL_LOG_WARN(...)
18
+ #define GGML_METAL_LOG_ERROR(...)
19
  #else
20
+ #define GGML_METAL_LOG_INFO(...) ggml_metal_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
21
+ #define GGML_METAL_LOG_WARN(...) ggml_metal_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
22
+ #define GGML_METAL_LOG_ERROR(...) ggml_metal_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
23
  #endif
24
 
25
  #define UNUSED(x) (void)(x)
26
 
27
+ #define GGML_MAX_CONCUR (2*GGML_DEFAULT_GRAPH_SIZE)
28
 
29
  struct ggml_metal_buffer {
30
  const char * name;
 
63
  GGML_METAL_DECL_KERNEL(mul);
64
  GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
65
  GGML_METAL_DECL_KERNEL(scale);
66
+ GGML_METAL_DECL_KERNEL(scale_4);
67
  GGML_METAL_DECL_KERNEL(silu);
68
  GGML_METAL_DECL_KERNEL(relu);
69
  GGML_METAL_DECL_KERNEL(gelu);
 
75
  GGML_METAL_DECL_KERNEL(get_rows_f16);
76
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
77
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
78
+ GGML_METAL_DECL_KERNEL(get_rows_q5_0);
79
+ GGML_METAL_DECL_KERNEL(get_rows_q5_1);
80
  GGML_METAL_DECL_KERNEL(get_rows_q8_0);
81
  GGML_METAL_DECL_KERNEL(get_rows_q2_K);
82
  GGML_METAL_DECL_KERNEL(get_rows_q3_K);
 
85
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
86
  GGML_METAL_DECL_KERNEL(rms_norm);
87
  GGML_METAL_DECL_KERNEL(norm);
88
+ GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
89
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
90
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
91
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
92
+ GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
93
+ GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
94
+ GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
95
+ GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
96
+ GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
97
+ GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
98
+ GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
99
+ GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
100
+ GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
101
+ GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
102
  GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
103
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
104
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
105
  GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
106
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
107
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
108
  GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
109
  GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
110
  GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
111
  GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
112
  GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
113
  GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
114
+ GGML_METAL_DECL_KERNEL(rope_f32);
115
+ GGML_METAL_DECL_KERNEL(rope_f16);
116
  GGML_METAL_DECL_KERNEL(alibi_f32);
117
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
118
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
119
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
120
+ GGML_METAL_DECL_KERNEL(concat);
121
+ GGML_METAL_DECL_KERNEL(sqr);
122
 
123
  #undef GGML_METAL_DECL_KERNEL
124
  };
 
134
  @implementation GGMLMetalClass
135
  @end
136
 
137
+ ggml_log_callback ggml_metal_log_callback = NULL;
138
+ void * ggml_metal_log_user_data = NULL;
139
+
140
+ void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
141
+ ggml_metal_log_callback = log_callback;
142
+ ggml_metal_log_user_data = user_data;
143
+ }
144
+
145
+ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
146
+ if (ggml_metal_log_callback != NULL) {
147
+ va_list args;
148
+ va_start(args, format);
149
+ char buffer[128];
150
+ int len = vsnprintf(buffer, 128, format, args);
151
+ if (len < 128) {
152
+ ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
153
+ } else {
154
+ char* buffer2 = malloc(len+1);
155
+ vsnprintf(buffer2, len+1, format, args);
156
+ buffer2[len] = 0;
157
+ ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
158
+ free(buffer2);
159
+ }
160
+ va_end(args);
161
+ }
162
+ }
163
+
164
+
165
+
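A minimal sketch of how a host application might hook into the logging machinery above, assuming only the ggml_metal_log_set_callback API and the ggml_log_callback signature used in this diff; the whisper_metal_log_cb name and the stderr formatting are illustrative, not part of the commit.

    #include <stdio.h>

    #include "ggml.h"        // enum ggml_log_level, ggml_log_callback
    #include "ggml-metal.h"  // ggml_metal_log_set_callback (added in this sync)

    // forward Metal log messages to stderr, tagging warnings and errors
    static void whisper_metal_log_cb(enum ggml_log_level level, const char * text, void * user_data) {
        (void) user_data;
        switch (level) {
            case GGML_LOG_LEVEL_ERROR: fprintf(stderr, "[metal error] %s", text); break;
            case GGML_LOG_LEVEL_WARN:  fprintf(stderr, "[metal warn] %s", text);  break;
            default:                   fprintf(stderr, "%s", text);               break;
        }
    }

    // install the callback once, before ggml_metal_init(), so that init-time
    // messages (device selection, library loading) are also captured
    static void whisper_metal_log_install(void) {
        ggml_metal_log_set_callback(whisper_metal_log_cb, NULL);
    }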
166
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
167
+ GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
168
 
169
  id <MTLDevice> device;
170
  NSString * s;
 
174
  NSArray * devices = MTLCopyAllDevices();
175
  for (device in devices) {
176
  s = [device name];
177
+ GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
178
  }
179
  #endif
180
 
181
  // Pick and show default Metal device
182
  device = MTLCreateSystemDefaultDevice();
183
  s = [device name];
184
+ GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
185
 
186
  // Configure context
187
  struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
 
193
 
194
  ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
195
 
196
+ // load library
 
197
  {
198
+ NSBundle * bundle = nil;
199
+ #ifdef SWIFT_PACKAGE
200
+ bundle = SWIFTPM_MODULE_BUNDLE;
201
  #else
202
+ bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
203
+ #endif
 
 
204
  NSError * error = nil;
205
+ NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
206
+ if (libPath != nil) {
207
+ NSURL * libURL = [NSURL fileURLWithPath:libPath];
208
+ GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
209
+ ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
210
+ } else {
211
+ GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
212
+
213
+ NSString * sourcePath;
214
+ NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
215
+ if (ggmlMetalPathResources) {
216
+ sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
217
+ } else {
218
+ sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
219
+ }
220
+ if (sourcePath == nil) {
221
+ GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
222
+ sourcePath = @"ggml-metal.metal";
223
+ }
224
+ GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [sourcePath UTF8String]);
225
+ NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error];
226
+ if (error) {
227
+ GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
228
+ return NULL;
229
+ }
230
 
231
+ MTLCompileOptions* options = nil;
232
  #ifdef GGML_QKK_64
233
+ options = [MTLCompileOptions new];
234
+ options.preprocessorMacros = @{ @"QK_K" : @(64) };
 
 
 
235
  #endif
236
+ ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
237
+ }
238
+
239
  if (error) {
240
+ GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
241
  return NULL;
242
  }
243
  }
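The load-library block above first looks for a precompiled default.metallib and otherwise compiles ggml-metal.metal from source, honoring the new GGML_METAL_PATH_RESOURCES environment variable when locating the shader. A minimal sketch of overriding that resource path from C, assuming only POSIX setenv; the directory shown is illustrative.

    #include <stdlib.h>

    // must run before ggml_metal_init(), since the shader source is located
    // while the Metal library is being loaded
    static void set_metal_resource_path(void) {
        setenv("GGML_METAL_PATH_RESOURCES", "/opt/whisper/metal-resources", 1 /*overwrite*/);
    }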
 
244
 
245
  // load kernels
246
  {
247
  NSError * error = nil;
248
+
249
+ /*
250
+ GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
251
+ (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
252
+ (int) ctx->pipeline_##name.threadExecutionWidth); \
253
+ */
254
  #define GGML_METAL_ADD_KERNEL(name) \
255
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
256
  ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
 
 
 
257
  if (error) { \
258
+ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
259
  return NULL; \
260
  }
261
 
 
264
  GGML_METAL_ADD_KERNEL(mul);
265
  GGML_METAL_ADD_KERNEL(mul_row);
266
  GGML_METAL_ADD_KERNEL(scale);
267
+ GGML_METAL_ADD_KERNEL(scale_4);
268
  GGML_METAL_ADD_KERNEL(silu);
269
  GGML_METAL_ADD_KERNEL(relu);
270
  GGML_METAL_ADD_KERNEL(gelu);
 
276
  GGML_METAL_ADD_KERNEL(get_rows_f16);
277
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
278
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
279
+ GGML_METAL_ADD_KERNEL(get_rows_q5_0);
280
+ GGML_METAL_ADD_KERNEL(get_rows_q5_1);
281
  GGML_METAL_ADD_KERNEL(get_rows_q8_0);
282
  GGML_METAL_ADD_KERNEL(get_rows_q2_K);
283
  GGML_METAL_ADD_KERNEL(get_rows_q3_K);
 
286
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
287
  GGML_METAL_ADD_KERNEL(rms_norm);
288
  GGML_METAL_ADD_KERNEL(norm);
289
+ GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
290
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
291
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
292
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
293
+ GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
294
+ GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
295
+ GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
296
+ GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
297
+ GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
298
+ GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
299
+ GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
300
+ GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
301
+ GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
302
+ GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
303
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
304
+ GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
305
+ GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
306
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
307
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
308
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
309
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
310
+ GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
311
+ GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
312
+ GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
313
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
314
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
315
+ GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
316
+ }
317
+ GGML_METAL_ADD_KERNEL(rope_f32);
318
+ GGML_METAL_ADD_KERNEL(rope_f16);
319
  GGML_METAL_ADD_KERNEL(alibi_f32);
320
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
321
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
322
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
323
+ GGML_METAL_ADD_KERNEL(concat);
324
+ GGML_METAL_ADD_KERNEL(sqr);
325
 
326
  #undef GGML_METAL_ADD_KERNEL
327
  }
328
 
 
329
  #if TARGET_OS_OSX
330
+ // print MTL GPU family:
331
+ GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
332
+
333
+ // determine max supported GPU family
334
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
335
+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
336
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
337
+ if ([ctx->device supportsFamily:i]) {
338
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
339
+ break;
340
+ }
341
+ }
342
+
343
+ GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
344
+ GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
345
  if (ctx->device.maxTransferRate != 0) {
346
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
347
  } else {
348
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
349
  }
350
  #endif
351
 
 
353
  }
354
 
355
  void ggml_metal_free(struct ggml_metal_context * ctx) {
356
+ GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
357
  #define GGML_METAL_DEL_KERNEL(name) \
358
  [ctx->function_##name release]; \
359
  [ctx->pipeline_##name release];
 
363
  GGML_METAL_DEL_KERNEL(mul);
364
  GGML_METAL_DEL_KERNEL(mul_row);
365
  GGML_METAL_DEL_KERNEL(scale);
366
+ GGML_METAL_DEL_KERNEL(scale_4);
367
  GGML_METAL_DEL_KERNEL(silu);
368
  GGML_METAL_DEL_KERNEL(relu);
369
  GGML_METAL_DEL_KERNEL(gelu);
 
375
  GGML_METAL_DEL_KERNEL(get_rows_f16);
376
  GGML_METAL_DEL_KERNEL(get_rows_q4_0);
377
  GGML_METAL_DEL_KERNEL(get_rows_q4_1);
378
+ GGML_METAL_DEL_KERNEL(get_rows_q5_0);
379
+ GGML_METAL_DEL_KERNEL(get_rows_q5_1);
380
  GGML_METAL_DEL_KERNEL(get_rows_q8_0);
381
  GGML_METAL_DEL_KERNEL(get_rows_q2_K);
382
  GGML_METAL_DEL_KERNEL(get_rows_q3_K);
 
385
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
386
  GGML_METAL_DEL_KERNEL(rms_norm);
387
  GGML_METAL_DEL_KERNEL(norm);
388
+ GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
389
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
390
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
391
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
392
+ GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
393
+ GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
394
+ GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
395
+ GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
396
+ GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
397
+ GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
398
+ GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
399
+ GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
400
+ GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
401
+ GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
402
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
403
+ GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
404
+ GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
405
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
406
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
407
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
408
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
409
+ GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
410
+ GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
411
+ GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
412
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
413
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
414
+ GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
415
+ }
416
+ GGML_METAL_DEL_KERNEL(rope_f32);
417
+ GGML_METAL_DEL_KERNEL(rope_f16);
418
  GGML_METAL_DEL_KERNEL(alibi_f32);
419
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
420
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
421
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
422
+ GGML_METAL_DEL_KERNEL(concat);
423
+ GGML_METAL_DEL_KERNEL(sqr);
424
 
425
  #undef GGML_METAL_DEL_KERNEL
426
 
 
441
  void * data = NULL;
442
  const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
443
  if (result != 0) {
444
+ GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
445
  return NULL;
446
  }
447
 
 
469
  // Metal buffer based on the host memory pointer
470
  //
471
  static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
472
+ //GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
473
 
474
  const int64_t tsize = ggml_nbytes(t);
475
 
 
477
  for (int i = 0; i < ctx->n_buffers; ++i) {
478
  const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
479
 
480
+ //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
481
  if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
482
  *offs = (size_t) ioffs;
483
 
484
+ //GGML_METAL_LOG_INFO("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
485
 
486
  return ctx->buffers[i].metal;
487
  }
488
  }
489
 
490
+ GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__);
491
 
492
  return nil;
493
  }
 
499
  size_t size,
500
  size_t max_size) {
501
  if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
502
+ GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
503
  return false;
504
  }
505
 
 
509
  const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
510
 
511
  if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
512
+ GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
513
  return false;
514
  }
515
  }
 
530
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
531
 
532
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
533
+ GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
534
  return false;
535
  }
536
 
537
+ GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
538
 
539
  ++ctx->n_buffers;
540
  } else {
 
554
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
555
 
556
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
557
+ GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
558
  return false;
559
  }
560
 
561
+ GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
562
  if (i + size_step < size) {
563
+ GGML_METAL_LOG_INFO("\n");
564
  }
565
 
566
  ++ctx->n_buffers;
 
568
  }
569
 
570
  #if TARGET_OS_OSX
571
+ GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
572
  ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
573
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
574
 
575
  if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
576
+ GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
577
  } else {
578
+ GGML_METAL_LOG_INFO("\n");
579
  }
580
  #else
581
+ GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
582
  #endif
583
  }
584
 
 
691
  }
692
 
693
  if (ctx->concur_list_len > GGML_MAX_CONCUR) {
694
+ GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__);
695
  }
696
  }
697
 
 
745
  continue;
746
  }
747
 
748
+ //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
749
 
750
  struct ggml_tensor * src0 = gf->nodes[i]->src[0];
751
  struct ggml_tensor * src1 = gf->nodes[i]->src[1];
752
  struct ggml_tensor * dst = gf->nodes[i];
753
 
754
+ switch (dst->op) {
755
+ case GGML_OP_NONE:
756
+ case GGML_OP_RESHAPE:
757
+ case GGML_OP_VIEW:
758
+ case GGML_OP_TRANSPOSE:
759
+ case GGML_OP_PERMUTE:
760
+ {
761
+ // noop -> next node
762
+ } continue;
763
+ default:
764
+ {
765
+ } break;
766
+ }
767
+
768
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
769
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
770
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
 
803
  id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
804
  id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
805
 
806
+ //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
807
  //if (src0) {
808
+ // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
809
  // ggml_is_contiguous(src0), src0->name);
810
  //}
811
  //if (src1) {
812
+ // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
813
  // ggml_is_contiguous(src1), src1->name);
814
  //}
815
  //if (dst) {
816
+ // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
817
  // dst->name);
818
  //}
819
 
820
  switch (dst->op) {
821
+ case GGML_OP_CONCAT:
 
 
 
 
822
  {
823
+ const int64_t nb = ne00;
824
+
825
+ [encoder setComputePipelineState:ctx->pipeline_concat];
826
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
827
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
828
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
829
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
830
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
831
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
832
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
833
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
834
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
835
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
836
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
837
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
838
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
839
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
840
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
841
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
842
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
843
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
844
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
845
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
846
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
847
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
848
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
849
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
850
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
851
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
852
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
853
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
854
+
855
+ const int nth = MIN(1024, ne0);
856
+
857
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
858
  } break;
859
  case GGML_OP_ADD:
860
  {
861
  GGML_ASSERT(ggml_is_contiguous(src0));
862
  GGML_ASSERT(ggml_is_contiguous(src1));
863
 
864
+ bool bcast_row = false;
 
 
865
 
866
+ int64_t nb = ne00;
867
+
868
+ if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
869
  // src1 is a row
870
  GGML_ASSERT(ne11 == 1);
871
+
872
+ nb = ne00 / 4;
873
  [encoder setComputePipelineState:ctx->pipeline_add_row];
874
+
875
+ bcast_row = true;
876
  } else {
877
  [encoder setComputePipelineState:ctx->pipeline_add];
878
  }
879
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
880
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
881
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
882
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
883
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
884
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
885
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
886
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
887
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
888
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
889
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
890
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
891
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
892
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
893
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
894
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
895
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
896
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
897
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
898
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
899
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
900
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
901
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
902
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
903
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
904
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
905
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
906
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
907
+
908
+ if (bcast_row) {
909
+ const int64_t n = ggml_nelements(dst)/4;
910
+
911
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
912
+ } else {
913
+ const int nth = MIN(1024, ne0);
914
 
915
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
916
+ }
917
  } break;
918
  case GGML_OP_MUL:
919
  {
 
946
 
947
  const float scale = *(const float *) src1->data;
948
 
949
+ int64_t n = ggml_nelements(dst);
950
+
951
+ if (n % 4 == 0) {
952
+ n /= 4;
953
+ [encoder setComputePipelineState:ctx->pipeline_scale_4];
954
+ } else {
955
+ [encoder setComputePipelineState:ctx->pipeline_scale];
956
+ }
957
+
958
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
959
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
960
  [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
961
 
 
 
962
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
963
  } break;
964
  case GGML_OP_UNARY:
 
969
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
970
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
971
 
972
+ const int64_t n = ggml_nelements(dst);
973
+ GGML_ASSERT(n % 4 == 0);
974
 
975
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
976
  } break;
977
  case GGML_UNARY_OP_RELU:
978
  {
 
990
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
991
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
992
 
993
+ const int64_t n = ggml_nelements(dst);
994
+ GGML_ASSERT(n % 4 == 0);
995
 
996
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
997
  } break;
998
  default:
999
  {
1000
+ GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1001
  GGML_ASSERT(false);
1002
  }
1003
  } break;
1004
+ case GGML_OP_SQR:
1005
+ {
1006
+ GGML_ASSERT(ggml_is_contiguous(src0));
1007
+
1008
+ [encoder setComputePipelineState:ctx->pipeline_sqr];
1009
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1010
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1011
+
1012
+ const int64_t n = ggml_nelements(dst);
1013
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1014
+ } break;
1015
  case GGML_OP_SOFT_MAX:
1016
  {
1017
+ int nth = 32; // SIMD width
1018
 
1019
  if (ne00%4 == 0) {
1020
  [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
1021
  } else {
1022
+ do {
1023
+ nth *= 2;
1024
+ } while (nth <= ne00 && nth <= 1024);
1025
+ nth /= 2;
1026
  [encoder setComputePipelineState:ctx->pipeline_soft_max];
1027
  }
1028
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
 
1030
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1031
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1032
  [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1033
+ [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
1034
 
1035
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1036
  } break;
1037
  case GGML_OP_DIAG_MASK_INF:
1038
  {
 
1058
  } break;
1059
  case GGML_OP_MUL_MAT:
1060
  {
 
 
1061
  GGML_ASSERT(ne00 == ne10);
 
 
1062
  GGML_ASSERT(ne03 == ne13);
1063
 
1064
+ const uint gqa = ne12/ne02;
1065
+
1066
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1067
+ // to the matrix-vector kernel
1068
+ int ne11_mm_min = 1;
1069
+
1070
+ #if 0
1071
+ // the numbers below are measured on M2 Ultra for 7B and 13B models
1072
+ // these numbers do not translate to other devices or model sizes
1073
+ // TODO: need to find a better approach
1074
+ if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
1075
+ switch (src0t) {
1076
+ case GGML_TYPE_F16: ne11_mm_min = 2; break;
1077
+ case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1078
+ case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1079
+ case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1080
+ case GGML_TYPE_Q4_0:
1081
+ case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1082
+ case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1083
+ case GGML_TYPE_Q5_0: // not tested yet
1084
+ case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1085
+ case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1086
+ case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1087
+ default: ne11_mm_min = 1; break;
1088
+ }
1089
+ }
1090
+ #endif
1091
+
1092
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1093
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1094
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1095
+ !ggml_is_transposed(src0) &&
1096
  !ggml_is_transposed(src1) &&
1097
  src1t == GGML_TYPE_F32 &&
1098
+ ne00 % 32 == 0 && ne00 >= 64 &&
1099
+ ne11 > ne11_mm_min) {
1100
+ //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1101
  switch (src0->type) {
1102
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
1103
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
1104
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
1105
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
1106
+ case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
1107
+ case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
1108
  case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
1109
  case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
1110
  case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
 
1128
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1129
  [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
1130
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1131
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1132
  } else {
1133
  int nth0 = 32;
1134
  int nth1 = 1;
1135
  int nrows = 1;
1136
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1137
 
1138
  // use custom matrix x vector kernel
1139
  switch (src0t) {
1140
  case GGML_TYPE_F32:
1141
  {
1142
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
1143
  nrows = 4;
1144
  } break;
1145
  case GGML_TYPE_F16:
 
1147
  nth0 = 32;
1148
  nth1 = 1;
1149
  if (ne11 * ne12 < 4) {
1150
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
1151
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1152
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
1153
  nrows = ne11;
1154
  } else {
1155
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
1156
  nrows = 4;
1157
  }
1158
  } break;
 
1163
 
1164
  nth0 = 8;
1165
  nth1 = 8;
1166
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
1167
  } break;
1168
  case GGML_TYPE_Q4_1:
1169
  {
 
1172
 
1173
  nth0 = 8;
1174
  nth1 = 8;
1175
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
1176
+ } break;
1177
+ case GGML_TYPE_Q5_0:
1178
+ {
1179
+ GGML_ASSERT(ne02 == 1);
1180
+ GGML_ASSERT(ne12 == 1);
1181
+
1182
+ nth0 = 8;
1183
+ nth1 = 8;
1184
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
1185
+ } break;
1186
+ case GGML_TYPE_Q5_1:
1187
+ {
1188
+ GGML_ASSERT(ne02 == 1);
1189
+ GGML_ASSERT(ne12 == 1);
1190
+
1191
+ nth0 = 8;
1192
+ nth1 = 8;
1193
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
1194
  } break;
1195
  case GGML_TYPE_Q8_0:
1196
  {
 
1199
 
1200
  nth0 = 8;
1201
  nth1 = 8;
1202
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
1203
  } break;
1204
  case GGML_TYPE_Q2_K:
1205
  {
 
1208
 
1209
  nth0 = 2;
1210
  nth1 = 32;
1211
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
1212
  } break;
1213
  case GGML_TYPE_Q3_K:
1214
  {
 
1217
 
1218
  nth0 = 2;
1219
  nth1 = 32;
1220
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
1221
  } break;
1222
  case GGML_TYPE_Q4_K:
1223
  {
 
1226
 
1227
  nth0 = 4; //1;
1228
  nth1 = 8; //32;
1229
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
1230
  } break;
1231
  case GGML_TYPE_Q5_K:
1232
  {
 
1235
 
1236
  nth0 = 2;
1237
  nth1 = 32;
1238
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
1239
  } break;
1240
  case GGML_TYPE_Q6_K:
1241
  {
 
1244
 
1245
  nth0 = 2;
1246
  nth1 = 32;
1247
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
1248
  } break;
1249
  default:
1250
  {
1251
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1252
  GGML_ASSERT(false && "not implemented");
1253
  }
1254
  };
 
1272
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1273
  [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
1274
 
1275
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1276
+ src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1277
+ src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
1278
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1279
  }
1280
  else if (src0t == GGML_TYPE_Q4_K) {
 
1305
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
1306
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
1307
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
1308
+ case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
1309
+ case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
1310
  case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
1311
  case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
1312
  case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
 
1329
  } break;
1330
  case GGML_OP_RMS_NORM:
1331
  {
1332
+ GGML_ASSERT(ne00 % 4 == 0);
1333
+
1334
  float eps;
1335
  memcpy(&eps, dst->op_params, sizeof(float));
1336
 
1337
+ const int nth = MIN(512, ne00);
1338
 
1339
  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
1340
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
 
1353
  float eps;
1354
  memcpy(&eps, dst->op_params, sizeof(float));
1355
 
1356
+ const int nth = MIN(256, ne00);
1357
 
1358
  [encoder setComputePipelineState:ctx->pipeline_norm];
1359
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
 
1371
  {
1372
  GGML_ASSERT((src0t == GGML_TYPE_F32));
1373
 
1374
+ const int nth = MIN(1024, ne00);
1375
+
1376
+ //const int n_past = ((int32_t *) dst->op_params)[0];
1377
  const int n_head = ((int32_t *) dst->op_params)[1];
1378
  float max_bias;
1379
  memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
1380
 
 
 
 
 
1381
  const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
1382
  const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
1383
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
1384
 
1385
  [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
1386
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
 
1401
  [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1402
  [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1403
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1404
+ [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
1405
+ [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
1406
+ [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
1407
 
1408
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1409
  } break;
1410
  case GGML_OP_ROPE:
1411
  {
1412
+ GGML_ASSERT(ne10 == ne02);
 
 
1413
 
1414
+ const int nth = MIN(1024, ne00);
 
 
 
1415
 
1416
+ const int n_past = ((int32_t *) dst->op_params)[0];
1417
+ const int n_dims = ((int32_t *) dst->op_params)[1];
1418
+ const int mode = ((int32_t *) dst->op_params)[2];
1419
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
1420
+
1421
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1422
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
1423
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
1424
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
1425
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
1426
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
1427
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
1428
+
1429
+ switch (src0->type) {
1430
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
1431
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
1432
+ default: GGML_ASSERT(false);
1433
+ };
1434
+
1435
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1436
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1437
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1438
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1439
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
1440
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
1441
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
1442
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
1443
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
1444
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
1445
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
1446
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
1447
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
1448
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
1449
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
1450
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
1451
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
1452
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
1453
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
1454
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
1455
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
1456
+ [encoder setBytes:&mode length:sizeof( int) atIndex:21];
1457
+ [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
1458
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
1459
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
1460
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
1461
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
1462
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
1463
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
1464
 
1465
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1466
  } break;
1467
  case GGML_OP_DUP:
1468
  case GGML_OP_CPY:
1469
  case GGML_OP_CONT:
1470
  {
1471
+ const int nth = MIN(1024, ne00);
1472
 
1473
  switch (src0t) {
1474
  case GGML_TYPE_F32:
 
1513
  } break;
1514
  default:
1515
  {
1516
+ GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1517
  GGML_ASSERT(false);
1518
  }
1519
  }
 
1538
 
1539
  MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
1540
  if (status != MTLCommandBufferStatusCompleted) {
1541
+ GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
1542
  GGML_ASSERT(false);
1543
  }
1544
  }
1545
 
1546
  }
1547
  }
1548
+
1549
+ ////////////////////////////////////////////////////////////////////////////////
1550
+
1551
+ // backend interface
1552
+
1553
+ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
1554
+ return "Metal";
1555
+
1556
+ UNUSED(backend);
1557
+ }
1558
+
1559
+ static void ggml_backend_metal_free(ggml_backend_t backend) {
1560
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
1561
+ ggml_metal_free(ctx);
1562
+ free(backend);
1563
+ }
1564
+
1565
+ static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
1566
+ return (void *)buffer->context;
1567
+ }
1568
+
1569
+ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
1570
+ free(buffer->context);
1571
+ UNUSED(buffer);
1572
+ }
1573
+
1574
+ static struct ggml_backend_buffer_i metal_backend_buffer_i = {
1575
+ /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
1576
+ /* .get_base = */ ggml_backend_metal_buffer_get_base,
1577
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
1578
+ /* .init_tensor = */ NULL, // no initialization required
1579
+ /* .free_tensor = */ NULL, // no cleanup required
1580
+ };
1581
+
1582
+ static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) {
1583
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
1584
+
1585
+ void * data = ggml_metal_host_malloc(size);
1586
+
1587
+ // TODO: set proper name of the buffers
1588
+ ggml_metal_add_buffer(ctx, "backend", data, size, 0);
1589
+
1590
+ return ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
1591
+ }
1592
+
1593
+ static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) {
1594
+ return 32;
1595
+ UNUSED(backend);
1596
+ }
1597
+
1598
+ static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
1599
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
1600
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
1601
+
1602
+ memcpy((char *)tensor->data + offset, data, size);
1603
+
1604
+ UNUSED(backend);
1605
+ }
1606
+
1607
+ static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
1608
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
1609
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
1610
+
1611
+ memcpy(data, (const char *)tensor->data + offset, size);
1612
+
1613
+ UNUSED(backend);
1614
+ }
1615
+
1616
+ static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
1617
+ UNUSED(backend);
1618
+ }
1619
+
1620
+ static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
1621
+ ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
1622
+
1623
+ UNUSED(backend);
1624
+ }
1625
+
1626
+ static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
1627
+ ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
1628
+
1629
+ UNUSED(backend);
1630
+ }
1631
+
1632
+ static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
1633
+ struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
1634
+
1635
+ ggml_metal_graph_compute(metal_ctx, cgraph);
1636
+ }
1637
+
1638
+ static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
1639
+ return true;
1640
+ UNUSED(backend);
1641
+ UNUSED(op);
1642
+ }
1643
+
1644
+ static struct ggml_backend_i metal_backend_i = {
1645
+ /* .get_name = */ ggml_backend_metal_name,
1646
+ /* .free = */ ggml_backend_metal_free,
1647
+ /* .alloc_buffer = */ ggml_backend_metal_alloc_buffer,
1648
+ /* .get_alignment = */ ggml_backend_metal_get_alignment,
1649
+ /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
1650
+ /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
1651
+ /* .synchronize = */ ggml_backend_metal_synchronize,
1652
+ /* .cpy_tensor_from = */ ggml_backend_metal_cpy_tensor_from,
1653
+ /* .cpy_tensor_to = */ ggml_backend_metal_cpy_tensor_to,
1654
+ /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
1655
+ /* .graph_plan_free = */ NULL,
1656
+ /* .graph_plan_compute = */ NULL,
1657
+ /* .graph_compute = */ ggml_backend_metal_graph_compute,
1658
+ /* .supports_op = */ ggml_backend_metal_supports_op,
1659
+ };
1660
+
1661
+ ggml_backend_t ggml_backend_metal_init(void) {
1662
+ struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
1663
+
1664
+ ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
1665
+
1666
+ ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
1667
+
1668
+ *metal_backend = (struct ggml_backend) {
1669
+ /* .interface = */ metal_backend_i,
1670
+ /* .context = */ ctx,
1671
+ };
1672
+
1673
+ return metal_backend;
1674
+ }
1675
+
1676
+ bool ggml_backend_is_metal(ggml_backend_t backend) {
1677
+ return backend->iface.get_name == ggml_backend_metal_name;
1678
+ }
1679
+
1680
+ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
1681
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
1682
+
1683
+ ggml_metal_set_n_cb(ctx, n_cb);
1684
+ }
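The functions above expose the Metal context through the new ggml-backend interface. Below is a minimal sketch of creating and tearing down the backend from a host program, using ggml_backend_metal_init, ggml_backend_is_metal and ggml_backend_metal_set_n_cb from this file; ggml_backend_free and the ggml-backend.h header are assumed to come from the backend sources added elsewhere in this sync.

    #include <stdio.h>

    #include "ggml-backend.h"
    #include "ggml-metal.h"

    int main(void) {
        // create the Metal backend (wraps ggml_metal_init internally)
        ggml_backend_t backend = ggml_backend_metal_init();
        if (backend == NULL || !ggml_backend_is_metal(backend)) {
            fprintf(stderr, "failed to initialize the Metal backend\n");
            return 1;
        }

        // optionally encode the graph with several command buffers
        ggml_backend_metal_set_n_cb(backend, 4);

        // ... allocate buffers and run graphs through the ggml-backend API ...

        // releases the Metal context via the backend's free callback
        ggml_backend_free(backend);
        return 0;
    }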
ggml-metal.metal CHANGED
@@ -13,23 +13,85 @@ typedef struct {
13
 
14
  #define QK4_1 32
15
  typedef struct {
16
- half d; // delta
17
- half m; // min
18
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
19
  } block_q4_1;
20
 
21
  #define QK8_0 32
22
  typedef struct {
23
  half d; // delta
24
  int8_t qs[QK8_0]; // quants
25
  } block_q8_0;
26
 
 
 
 
27
  kernel void kernel_add(
28
- device const float4 * src0,
29
- device const float4 * src1,
30
- device float4 * dst,
31
- uint tpig[[thread_position_in_grid]]) {
32
- dst[tpig] = src0[tpig] + src1[tpig];
33
  }
34
 
35
  // assumption: src1 is a row
@@ -38,7 +100,7 @@ kernel void kernel_add_row(
38
  device const float4 * src0,
39
  device const float4 * src1,
40
  device float4 * dst,
41
- constant int64_t & nb,
42
  uint tpig[[thread_position_in_grid]]) {
43
  dst[tpig] = src0[tpig] + src1[tpig % nb];
44
  }
@@ -63,9 +125,17 @@ kernel void kernel_mul_row(
63
  }
64
 
65
  kernel void kernel_scale(
66
  device const float4 * src0,
67
  device float4 * dst,
68
- constant float & scale,
69
  uint tpig[[thread_position_in_grid]]) {
70
  dst[tpig] = src0[tpig] * scale;
71
  }
@@ -85,6 +155,13 @@ kernel void kernel_relu(
85
  dst[tpig] = max(0.0f, src0[tpig]);
86
  }
87
 
88
  constant float GELU_COEF_A = 0.044715f;
89
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
90
 
@@ -107,36 +184,73 @@ kernel void kernel_soft_max(
107
  constant int64_t & ne00,
108
  constant int64_t & ne01,
109
  constant int64_t & ne02,
110
- uint3 tgpig[[threadgroup_position_in_grid]],
111
- uint3 tpitg[[thread_position_in_threadgroup]],
112
- uint3 ntg[[threads_per_threadgroup]]) {
113
- const int64_t i03 = tgpig[2];
114
- const int64_t i02 = tgpig[1];
115
- const int64_t i01 = tgpig[0];
 
 
 
116
 
117
  device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
118
  device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
119
 
120
  // parallel max
121
- float lmax = psrc0[tpitg[0]];
122
- for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
 
123
  lmax = MAX(lmax, psrc0[i00]);
124
  }
125
- const float max = simd_max(lmax);
126
 
127
  // parallel sum
128
  float lsum = 0.0f;
129
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
130
  const float exp_psrc0 = exp(psrc0[i00] - max);
131
  lsum += exp_psrc0;
132
  // Remember the result of exp here. exp is expensive, so we really do not
133
- // whish to compute it twice.
134
  pdst[i00] = exp_psrc0;
135
  }
136
 
137
- const float sum = simd_sum(lsum);
138
 
139
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
140
  pdst[i00] /= sum;
141
  }
142
  }
@@ -147,37 +261,73 @@ kernel void kernel_soft_max_4(
147
  constant int64_t & ne00,
148
  constant int64_t & ne01,
149
  constant int64_t & ne02,
150
- uint3 tgpig[[threadgroup_position_in_grid]],
151
- uint3 tpitg[[thread_position_in_threadgroup]],
152
- uint3 ntg[[threads_per_threadgroup]]) {
153
- const int64_t i03 = tgpig[2];
154
- const int64_t i02 = tgpig[1];
155
- const int64_t i01 = tgpig[0];
 
 
 
156
 
157
  device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
158
  device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
159
 
160
  // parallel max
161
- float4 lmax4 = psrc4[tpitg[0]];
162
- for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
 
163
  lmax4 = fmax(lmax4, psrc4[i00]);
164
  }
165
- float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
166
 
167
- const float max = simd_max(lmax);
168
 
169
  // parallel sum
170
  float4 lsum4 = 0.0f;
171
- for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
172
  const float4 exp_psrc4 = exp(psrc4[i00] - max);
173
  lsum4 += exp_psrc4;
174
  pdst4[i00] = exp_psrc4;
175
  }
176
- float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
177
 
178
- const float sum = simd_sum(lsum);
179
 
180
- for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
181
  pdst4[i00] /= sum;
182
  }
183
  }
@@ -197,7 +347,7 @@ kernel void kernel_diag_mask_inf(
197
  dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
198
  } else {
199
  dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
200
- }
201
  }
202
 
203
  kernel void kernel_diag_mask_inf_8(
@@ -291,10 +441,11 @@ kernel void kernel_rms_norm(
291
  uint sgitg[[simdgroup_index_in_threadgroup]],
292
  uint tiisg[[thread_index_in_simdgroup]],
293
  uint ntg[[threads_per_threadgroup]]) {
294
- device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
295
- device const float * x_scalar = (device const float *) x;
296
- float4 sumf=0;
297
- float all_sum=0;
 
298
 
299
  // parallel sum
300
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
@@ -307,6 +458,7 @@ kernel void kernel_rms_norm(
307
  }
308
 
309
  threadgroup_barrier(mem_flags::mem_threadgroup);
 
310
  // broadcast, simd group number is ntg / 32
311
  for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
312
  if (tpitg < i) {
@@ -314,7 +466,9 @@ kernel void kernel_rms_norm(
314
  }
315
  }
316
  if (tpitg == 0) {
317
- for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
 
 
318
  sum[0] /= ne00;
319
  }
320
 
@@ -329,7 +483,9 @@ kernel void kernel_rms_norm(
329
  y[i00] = x[i00] * scale;
330
  }
331
  if (tpitg == 0) {
332
- for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
 
 
333
  }
334
  }
335
 
@@ -339,8 +495,11 @@ kernel void kernel_rms_norm(
339
  // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
340
  inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
341
  float d = qb_curr->d;
 
342
  float2 acc = 0.f;
 
343
  device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
 
344
  for (int i = 0; i < 8; i+=2) {
345
  acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
346
  + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -357,8 +516,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
357
  inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
358
  float d = qb_curr->d;
359
  float m = qb_curr->m;
360
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
361
  float2 acc = 0.f;
 
 
 
362
  for (int i = 0; i < 8; i+=2) {
363
  acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
364
  + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -368,9 +530,52 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
368
  return d * (acc[0] + acc[1]) + sumy * m;
369
  }
370
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  // putting them in the kernel cause a significant performance penalty
372
- #define N_DST 4 // each SIMD group works on 4 rows
373
- #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
374
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
375
  //Note: This is a template, but strictly speaking it only applies to
376
  // quantizations where the block size is 32. It also does not
@@ -381,18 +586,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
381
  int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
382
  uint3 tgpig, uint tiisg, uint sgitg) {
383
  const int nb = ne00/QK4_0;
 
384
  const int r0 = tgpig.x;
385
  const int r1 = tgpig.y;
386
  const int im = tgpig.z;
 
387
  const int first_row = (r0 * nsg + sgitg) * nr;
 
388
  const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
 
389
  device const block_q_type * x = (device const block_q_type *) src0 + offset0;
390
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
391
- float yl[16]; // src1 vector cache
392
- float sumf[nr]={0.f};
393
 
394
- const int ix = tiisg/2;
395
- const int il = 8*(tiisg%2);
 
 
 
396
 
397
  device const float * yb = y + ix * QK4_0 + il;
398
 
@@ -403,6 +613,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
403
  sumy += yb[i] + yb[i+1];
404
  yl[i+0] = yb[i+ 0];
405
  yl[i+1] = yb[i+ 1]/256.f;
 
406
  sumy += yb[i+16] + yb[i+17];
407
  yl[i+8] = yb[i+16]/16.f;
408
  yl[i+9] = yb[i+17]/4096.f;
@@ -418,12 +629,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
418
  for (int row = 0; row < nr; ++row) {
419
  const float tot = simd_sum(sumf[row]);
420
  if (tiisg == 0 && first_row + row < ne01) {
421
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
422
  }
423
  }
424
  }
425
 
426
- kernel void kernel_mul_mat_q4_0_f32(
427
  device const void * src0,
428
  device const float * src1,
429
  device float * dst,
@@ -436,12 +647,12 @@ kernel void kernel_mul_mat_q4_0_f32(
436
  constant int64_t & ne1[[buffer(16)]],
437
  constant uint & gqa[[buffer(17)]],
438
  uint3 tgpig[[threadgroup_position_in_grid]],
439
- uint tiisg[[thread_index_in_simdgroup]],
440
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
441
  mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
442
  }
443
 
444
- kernel void kernel_mul_mat_q4_1_f32(
445
  device const void * src0,
446
  device const float * src1,
447
  device float * dst,
@@ -459,9 +670,46 @@ kernel void kernel_mul_mat_q4_1_f32(
459
  mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
460
  }
461
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
462
  #define NB_Q8_0 8
463
 
464
- kernel void kernel_mul_mat_q8_0_f32(
465
  device const void * src0,
466
  device const float * src1,
467
  device float * dst,
@@ -525,7 +773,7 @@ kernel void kernel_mul_mat_q8_0_f32(
525
 
526
  #define N_F32_F32 4
527
 
528
- kernel void kernel_mul_mat_f32_f32(
529
  device const char * src0,
530
  device const char * src1,
531
  device float * dst,
@@ -596,7 +844,7 @@ kernel void kernel_mul_mat_f32_f32(
596
  }
597
  }
598
 
599
- kernel void kernel_mul_mat_f16_f32_1row(
600
  device const char * src0,
601
  device const char * src1,
602
  device float * dst,
@@ -615,7 +863,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
615
  constant int64_t & ne0,
616
  constant int64_t & ne1,
617
  uint3 tgpig[[threadgroup_position_in_grid]],
618
- uint tiisg[[thread_index_in_simdgroup]]) {
619
 
620
  const int64_t r0 = tgpig.x;
621
  const int64_t r1 = tgpig.y;
@@ -650,7 +898,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
650
 
651
  #define N_F16_F32 4
652
 
653
- kernel void kernel_mul_mat_f16_f32(
654
  device const char * src0,
655
  device const char * src1,
656
  device float * dst,
@@ -722,7 +970,7 @@ kernel void kernel_mul_mat_f16_f32(
722
  }
723
 
724
  // Assumes row size (ne00) is a multiple of 4
725
- kernel void kernel_mul_mat_f16_f32_l4(
726
  device const char * src0,
727
  device const char * src1,
728
  device float * dst,
@@ -783,7 +1031,9 @@ kernel void kernel_alibi_f32(
783
  constant uint64_t & nb1,
784
  constant uint64_t & nb2,
785
  constant uint64_t & nb3,
786
- constant float & m0,
 
 
787
  uint3 tgpig[[threadgroup_position_in_grid]],
788
  uint3 tpitg[[thread_position_in_threadgroup]],
789
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -799,37 +1049,122 @@ kernel void kernel_alibi_f32(
799
  const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
800
 
801
  device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
802
- float m_k = pow(m0, i2 + 1);
 
 
 
 
 
803
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
804
  device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
805
  dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
806
  }
807
  }
808
 
809
  kernel void kernel_rope(
810
- device const void * src0,
811
- device float * dst,
812
- constant int64_t & ne00,
813
- constant int64_t & ne01,
814
- constant int64_t & ne02,
815
- constant int64_t & ne03,
816
- constant uint64_t & nb00,
817
- constant uint64_t & nb01,
818
- constant uint64_t & nb02,
819
- constant uint64_t & nb03,
820
- constant int64_t & ne0,
821
- constant int64_t & ne1,
822
- constant int64_t & ne2,
823
- constant int64_t & ne3,
824
- constant uint64_t & nb0,
825
- constant uint64_t & nb1,
826
- constant uint64_t & nb2,
827
- constant uint64_t & nb3,
828
- constant int & n_past,
829
- constant int & n_dims,
830
- constant int & mode,
831
- constant float & freq_base,
832
- constant float & freq_scale,
 
 
 
 
 
 
833
  uint tiitg[[thread_index_in_threadgroup]],
834
  uint3 tptg[[threads_per_threadgroup]],
835
  uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -839,23 +1174,28 @@ kernel void kernel_rope(
839
 
840
  const bool is_neox = mode & 2;
841
 
842
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
 
843
 
844
- const float theta_0 = freq_scale * (float)p;
 
 
 
 
845
  const float inv_ndims = -1.f/n_dims;
846
 
847
  if (!is_neox) {
848
  for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
849
 
850
  const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
851
- const float cos_theta = cos(theta);
852
- const float sin_theta = sin(theta);
853
 
854
- device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
855
- device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
856
 
857
- const float x0 = src[0];
858
- const float x1 = src[1];
859
 
860
  dst_data[0] = x0*cos_theta - x1*sin_theta;
861
  dst_data[1] = x0*sin_theta + x1*cos_theta;
@@ -864,14 +1204,17 @@ kernel void kernel_rope(
864
  for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
865
  for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
866
 
867
- const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
868
- const float cos_theta = cos(theta);
869
- const float sin_theta = sin(theta);
 
 
 
870
 
871
  const int64_t i0 = ib*n_dims + ic/2;
872
 
873
- device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
874
- device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
875
 
876
  const float x0 = src[0];
877
  const float x1 = src[n_dims/2];
@@ -883,6 +1226,9 @@ kernel void kernel_rope(
883
  }
884
  }
885
 
 
 
 
886
  kernel void kernel_cpy_f16_f16(
887
  device const half * src0,
888
  device half * dst,
@@ -1008,6 +1354,62 @@ kernel void kernel_cpy_f32_f32(
1008
  }
1009
  }
1010
 
1011
  //============================================ k-quants ======================================================
1012
 
1013
  #ifndef QK_K
@@ -1100,7 +1502,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
1100
 
1101
  //====================================== dot products =========================
1102
 
1103
- kernel void kernel_mul_mat_q2_K_f32(
1104
  device const void * src0,
1105
  device const float * src1,
1106
  device float * dst,
@@ -1244,7 +1646,7 @@ kernel void kernel_mul_mat_q2_K_f32(
1244
  }
1245
 
1246
  #if QK_K == 256
1247
- kernel void kernel_mul_mat_q3_K_f32(
1248
  device const void * src0,
1249
  device const float * src1,
1250
  device float * dst,
@@ -1273,8 +1675,8 @@ kernel void kernel_mul_mat_q3_K_f32(
1273
 
1274
  float yl[32];
1275
 
1276
- const uint16_t kmask1 = 0x3030;
1277
- const uint16_t kmask2 = 0x0f0f;
1278
 
1279
  const int tid = tiisg/4;
1280
  const int ix = tiisg%4;
@@ -1396,7 +1798,7 @@ kernel void kernel_mul_mat_q3_K_f32(
1396
  }
1397
  }
1398
  #else
1399
- kernel void kernel_mul_mat_q3_K_f32(
1400
  device const void * src0,
1401
  device const float * src1,
1402
  device float * dst,
@@ -1467,7 +1869,7 @@ kernel void kernel_mul_mat_q3_K_f32(
1467
  #endif
1468
 
1469
  #if QK_K == 256
1470
- kernel void kernel_mul_mat_q4_K_f32(
1471
  device const void * src0,
1472
  device const float * src1,
1473
  device float * dst,
@@ -1573,7 +1975,7 @@ kernel void kernel_mul_mat_q4_K_f32(
1573
  }
1574
  }
1575
  #else
1576
- kernel void kernel_mul_mat_q4_K_f32(
1577
  device const void * src0,
1578
  device const float * src1,
1579
  device float * dst,
@@ -1662,7 +2064,7 @@ kernel void kernel_mul_mat_q4_K_f32(
1662
  }
1663
  #endif
1664
 
1665
- kernel void kernel_mul_mat_q5_K_f32(
1666
  device const void * src0,
1667
  device const float * src1,
1668
  device float * dst,
@@ -1835,7 +2237,7 @@ kernel void kernel_mul_mat_q5_K_f32(
1835
 
1836
  }
1837
 
1838
- kernel void kernel_mul_mat_q6_K_f32(
1839
  device const void * src0,
1840
  device const float * src1,
1841
  device float * dst,
@@ -1984,6 +2386,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
1984
  }
1985
  }
1986
 
1987
  template <typename type4x4>
1988
  void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
1989
  device const int8_t * qs = ((device const int8_t *)xb->qs);
@@ -2173,7 +2631,7 @@ kernel void kernel_get_rows(
2173
  }
2174
 
2175
  #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
2176
- #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
2177
  #define BLOCK_SIZE_K 32
2178
  #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
2179
  #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
@@ -2210,9 +2668,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
2210
  const uint r0 = tgpig.y;
2211
  const uint r1 = tgpig.x;
2212
  const uint im = tgpig.z;
 
2213
  // if this block is of 64x32 shape or smaller
2214
  short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
2215
  short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
 
2216
  // a thread shouldn't load data outside of the matrix
2217
  short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
2218
  short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
@@ -2236,26 +2696,30 @@ kernel void kernel_mul_mm(device const uchar * src0,
2236
  + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
2237
 
2238
  for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
2239
- //load data and store to threadgroup memory
2240
  half4x4 temp_a;
2241
  dequantize_func(x, il, temp_a);
2242
  threadgroup_barrier(mem_flags::mem_threadgroup);
 
2243
  #pragma unroll(16)
2244
  for (int i = 0; i < 16; i++) {
2245
  *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
2246
- + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
2247
- + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
2248
  }
2249
- *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
2250
- = *((device float2x4 *)y);
 
2251
  il = (il + 2 < nl) ? il + 2 : il % 2;
2252
  x = (il < 2) ? x + (2+nl-1)/nl : x;
2253
  y += BLOCK_SIZE_K;
2254
 
2255
  threadgroup_barrier(mem_flags::mem_threadgroup);
2256
- //load matrices from threadgroup memory and conduct outer products
 
2257
  threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
2258
  threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
 
2259
  #pragma unroll(4)
2260
  for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
2261
  #pragma unroll(4)
@@ -2270,6 +2734,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
2270
 
2271
  lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
2272
  lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
 
2273
  #pragma unroll(8)
2274
  for (int i = 0; i < 8; i++){
2275
  simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
@@ -2278,25 +2743,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
2278
  }
2279
 
2280
  if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
2281
- device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
2282
- + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
2283
  for (int i = 0; i < 8; i++) {
2284
  simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
2285
  }
2286
  } else {
2287
  // block is smaller than 64x32, we should avoid writing data outside of the matrix
2288
  threadgroup_barrier(mem_flags::mem_threadgroup);
2289
- threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
2290
  + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
2291
  for (int i = 0; i < 8; i++) {
2292
  simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
2293
  }
2294
 
2295
  threadgroup_barrier(mem_flags::mem_threadgroup);
2296
- device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
2297
- if (sgitg==0) {
 
2298
  for (int i = 0; i < n_rows; i++) {
2299
- for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
2300
  *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
2301
  }
2302
  }
@@ -2317,6 +2783,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
2317
  template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
2318
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
2319
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
 
 
2320
  template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
2321
  template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
2322
  template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -2345,6 +2813,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f
2345
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
2346
  template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
2347
  template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
 
 
2348
  template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
2349
  template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
2350
  template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
 
13
 
14
  #define QK4_1 32
15
  typedef struct {
16
+ half d; // delta
17
+ half m; // min
18
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
19
  } block_q4_1;
20
 
21
+ #define QK5_0 32
22
+ typedef struct {
23
+ half d; // delta
24
+ uint8_t qh[4]; // 5-th bit of quants
25
+ uint8_t qs[QK5_0 / 2]; // nibbles / quants
26
+ } block_q5_0;
27
+
28
+ #define QK5_1 32
29
+ typedef struct {
30
+ half d; // delta
31
+ half m; // min
32
+ uint8_t qh[4]; // 5-th bit of quants
33
+ uint8_t qs[QK5_1 / 2]; // nibbles / quants
34
+ } block_q5_1;
35
+
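The two new 5-bit block formats above pack 32 weights per block: the low 4 bits of each weight live in `qs`, and the fifth bits are gathered into the four `qh` bytes (read as one 32-bit word by the kernels). A minimal host-side sketch of the resulting storage cost, with `half` approximated as `uint16_t` (an assumption of this sketch, matching the 2-byte size of the Metal type):

```cpp
#include <cstdint>
#include <cstdio>

// Host-side mirrors of the Metal structs above (half approximated as uint16_t).
struct block_q5_0 { uint16_t d;             uint8_t qh[4]; uint8_t qs[16]; }; // 32 weights
struct block_q5_1 { uint16_t d; uint16_t m; uint8_t qh[4]; uint8_t qs[16]; }; // 32 weights

int main() {
    // 22 bytes / 32 weights = 5.5 bits per weight for q5_0,
    // 24 bytes / 32 weights = 6.0 bits per weight for q5_1.
    printf("q5_0: %zu bytes -> %.2f bits/weight\n", sizeof(block_q5_0), 8.0 * sizeof(block_q5_0) / 32);
    printf("q5_1: %zu bytes -> %.2f bits/weight\n", sizeof(block_q5_1), 8.0 * sizeof(block_q5_1) / 32);
    return 0;
}
```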
36
  #define QK8_0 32
37
  typedef struct {
38
  half d; // delta
39
  int8_t qs[QK8_0]; // quants
40
  } block_q8_0;
41
 
42
+ // general-purpose kernel for addition of two tensors
43
+ // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
44
+ // cons: not very efficient
45
  kernel void kernel_add(
46
+ device const char * src0,
47
+ device const char * src1,
48
+ device char * dst,
49
+ constant int64_t & ne00,
50
+ constant int64_t & ne01,
51
+ constant int64_t & ne02,
52
+ constant int64_t & ne03,
53
+ constant int64_t & nb00,
54
+ constant int64_t & nb01,
55
+ constant int64_t & nb02,
56
+ constant int64_t & nb03,
57
+ constant int64_t & ne10,
58
+ constant int64_t & ne11,
59
+ constant int64_t & ne12,
60
+ constant int64_t & ne13,
61
+ constant int64_t & nb10,
62
+ constant int64_t & nb11,
63
+ constant int64_t & nb12,
64
+ constant int64_t & nb13,
65
+ constant int64_t & ne0,
66
+ constant int64_t & ne1,
67
+ constant int64_t & ne2,
68
+ constant int64_t & ne3,
69
+ constant int64_t & nb0,
70
+ constant int64_t & nb1,
71
+ constant int64_t & nb2,
72
+ constant int64_t & nb3,
73
+ uint3 tgpig[[threadgroup_position_in_grid]],
74
+ uint3 tpitg[[thread_position_in_threadgroup]],
75
+ uint3 ntg[[threads_per_threadgroup]]) {
76
+ const int64_t i03 = tgpig.z;
77
+ const int64_t i02 = tgpig.y;
78
+ const int64_t i01 = tgpig.x;
79
+
80
+ const int64_t i13 = i03 % ne13;
81
+ const int64_t i12 = i02 % ne12;
82
+ const int64_t i11 = i01 % ne11;
83
+
84
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
85
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
86
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
87
+
88
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
89
+ ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
90
+
91
+ src0_ptr += ntg.x*nb00;
92
+ src1_ptr += ntg.x*nb10;
93
+ dst_ptr += ntg.x*nb0;
94
+ }
95
  }
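The broadcast rule used by `kernel_add` above is just a modulo on the outer indices: the element at (i01, i02, i03) of the first operand is paired with (i01 % ne11, i02 % ne12, i03 % ne13) of the second. A rough CPU analogue (plain C++, contiguous float tensors only, src1's row length assumed equal to ne00; an illustrative sketch, not the ggml CPU path):

```cpp
#include <cstdint>

// Broadcast add over dims 1..3, mirroring the index math in kernel_add above.
// Assumes contiguous float data and that src1's outer dims divide src0's.
void add_broadcast(const float * src0, const float * src1, float * dst,
                   int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
                   int64_t ne11, int64_t ne12, int64_t ne13) {
    for (int64_t i03 = 0; i03 < ne03; ++i03)
    for (int64_t i02 = 0; i02 < ne02; ++i02)
    for (int64_t i01 = 0; i01 < ne01; ++i01) {
        const float * x = src0 + ((i03*ne02 + i02)*ne01 + i01)*ne00;
        const float * y = src1 + (((i03 % ne13)*ne12 + (i02 % ne12))*ne11 + (i01 % ne11))*ne00;
        float       * d = dst  + ((i03*ne02 + i02)*ne01 + i01)*ne00;
        for (int64_t i00 = 0; i00 < ne00; ++i00) {
            d[i00] = x[i00] + y[i00];   // dim 0 is not broadcast, as in the kernel
        }
    }
}
```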
96
 
97
  // assumption: src1 is a row
 
100
  device const float4 * src0,
101
  device const float4 * src1,
102
  device float4 * dst,
103
+ constant int64_t & nb [[buffer(27)]],
104
  uint tpig[[thread_position_in_grid]]) {
105
  dst[tpig] = src0[tpig] + src1[tpig % nb];
106
  }
 
125
  }
126
 
127
  kernel void kernel_scale(
128
+ device const float * src0,
129
+ device float * dst,
130
+ constant float & scale,
131
+ uint tpig[[thread_position_in_grid]]) {
132
+ dst[tpig] = src0[tpig] * scale;
133
+ }
134
+
135
+ kernel void kernel_scale_4(
136
  device const float4 * src0,
137
  device float4 * dst,
138
+ constant float & scale,
139
  uint tpig[[thread_position_in_grid]]) {
140
  dst[tpig] = src0[tpig] * scale;
141
  }
 
155
  dst[tpig] = max(0.0f, src0[tpig]);
156
  }
157
 
158
+ kernel void kernel_sqr(
159
+ device const float * src0,
160
+ device float * dst,
161
+ uint tpig[[thread_position_in_grid]]) {
162
+ dst[tpig] = src0[tpig] * src0[tpig];
163
+ }
164
+
165
  constant float GELU_COEF_A = 0.044715f;
166
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
167
 
 
184
  constant int64_t & ne00,
185
  constant int64_t & ne01,
186
  constant int64_t & ne02,
187
+ threadgroup float * buf [[threadgroup(0)]],
188
+ uint tgpig[[threadgroup_position_in_grid]],
189
+ uint tpitg[[thread_position_in_threadgroup]],
190
+ uint sgitg[[simdgroup_index_in_threadgroup]],
191
+ uint tiisg[[thread_index_in_simdgroup]],
192
+ uint ntg[[threads_per_threadgroup]]) {
193
+ const int64_t i03 = (tgpig) / (ne02*ne01);
194
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
195
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
196
 
197
  device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
198
  device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
199
 
200
  // parallel max
201
+ float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
202
+
203
+ for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
204
  lmax = MAX(lmax, psrc0[i00]);
205
  }
206
+
207
+ float max = simd_max(lmax);
208
+ if (tiisg == 0) {
209
+ buf[sgitg] = max;
210
+ }
211
+
212
+ threadgroup_barrier(mem_flags::mem_threadgroup);
213
+
214
+ // broadcast, simd group number is ntg / 32
215
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
216
+ if (tpitg < i) {
217
+ buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
218
+ }
219
+ }
220
+
221
+ threadgroup_barrier(mem_flags::mem_threadgroup);
222
+
223
+ max = buf[0];
224
 
225
  // parallel sum
226
  float lsum = 0.0f;
227
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
228
  const float exp_psrc0 = exp(psrc0[i00] - max);
229
  lsum += exp_psrc0;
230
  // Remember the result of exp here. exp is expensive, so we really do not
231
+ // wish to compute it twice.
232
  pdst[i00] = exp_psrc0;
233
  }
234
 
235
+ float sum = simd_sum(lsum);
236
+ if (tiisg == 0) {
237
+ buf[sgitg] = sum;
238
+ }
239
+
240
+ threadgroup_barrier(mem_flags::mem_threadgroup);
241
+
242
+ // broadcast, simd group number is ntg / 32
243
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
244
+ if (tpitg < i) {
245
+ buf[tpitg] += buf[tpitg + i];
246
+ }
247
+ }
248
+
249
+ threadgroup_barrier(mem_flags::mem_threadgroup);
250
+
251
+ sum = buf[0];
252
 
253
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
254
  pdst[i00] /= sum;
255
  }
256
  }
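Numerically, the kernel above is the standard three-pass, max-subtracted softmax: find the row maximum, accumulate exp(x - max) while caching the exponentials, then divide by the sum. The simdgroup/threadgroup additions are only the GPU reduction of the max and the sum. A scalar reference of the same math (plain C++):

```cpp
#include <algorithm>
#include <cmath>

// Reference row softmax: the same three phases as kernel_soft_max above,
// minus the two-level (simdgroup -> threadgroup) parallel reductions.
void soft_max_row(const float * x, float * y, int n) {
    float max = -INFINITY;
    for (int i = 0; i < n; ++i) {
        max = std::max(max, x[i]);       // "parallel max" in the kernel
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        y[i] = std::exp(x[i] - max);     // exp is cached in pdst, as the kernel comment notes
        sum += y[i];                     // "parallel sum" in the kernel
    }
    for (int i = 0; i < n; ++i) {
        y[i] /= sum;
    }
}
```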
 
261
  constant int64_t & ne00,
262
  constant int64_t & ne01,
263
  constant int64_t & ne02,
264
+ threadgroup float * buf [[threadgroup(0)]],
265
+ uint tgpig[[threadgroup_position_in_grid]],
266
+ uint tpitg[[thread_position_in_threadgroup]],
267
+ uint sgitg[[simdgroup_index_in_threadgroup]],
268
+ uint tiisg[[thread_index_in_simdgroup]],
269
+ uint ntg[[threads_per_threadgroup]]) {
270
+ const int64_t i03 = (tgpig) / (ne02*ne01);
271
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
272
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
273
 
274
  device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
275
  device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
276
 
277
  // parallel max
278
+ float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
279
+
280
+ for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
281
  lmax4 = fmax(lmax4, psrc4[i00]);
282
  }
 
283
 
284
+ const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
285
+ float max = simd_max(lmax);
286
+ if (tiisg == 0) {
287
+ buf[sgitg] = max;
288
+ }
289
+
290
+ threadgroup_barrier(mem_flags::mem_threadgroup);
291
+
292
+ // broadcast, simd group number is ntg / 32
293
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
294
+ if (tpitg < i) {
295
+ buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
296
+ }
297
+ }
298
+
299
+ threadgroup_barrier(mem_flags::mem_threadgroup);
300
+
301
+ max = buf[0];
302
 
303
  // parallel sum
304
  float4 lsum4 = 0.0f;
305
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
306
  const float4 exp_psrc4 = exp(psrc4[i00] - max);
307
  lsum4 += exp_psrc4;
308
  pdst4[i00] = exp_psrc4;
309
  }
 
310
 
311
+ const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
312
+ float sum = simd_sum(lsum);
313
+ if (tiisg == 0) {
314
+ buf[sgitg] = sum;
315
+ }
316
+
317
+ threadgroup_barrier(mem_flags::mem_threadgroup);
318
+
319
+ // broadcast, simd group number is ntg / 32
320
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
321
+ if (tpitg < i) {
322
+ buf[tpitg] += buf[tpitg + i];
323
+ }
324
+ }
325
+
326
+ threadgroup_barrier(mem_flags::mem_threadgroup);
327
+
328
+ sum = buf[0];
329
 
330
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
331
  pdst4[i00] /= sum;
332
  }
333
  }
 
347
  dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
348
  } else {
349
  dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
350
+ }
351
  }
352
 
353
  kernel void kernel_diag_mask_inf_8(
 
441
  uint sgitg[[simdgroup_index_in_threadgroup]],
442
  uint tiisg[[thread_index_in_simdgroup]],
443
  uint ntg[[threads_per_threadgroup]]) {
444
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
445
+ device const float * x_scalar = (device const float *) x;
446
+
447
+ float4 sumf = 0;
448
+ float all_sum = 0;
449
 
450
  // parallel sum
451
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
 
458
  }
459
 
460
  threadgroup_barrier(mem_flags::mem_threadgroup);
461
+
462
  // broadcast, simd group number is ntg / 32
463
  for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
464
  if (tpitg < i) {
 
466
  }
467
  }
468
  if (tpitg == 0) {
469
+ for (int i = 4 * (ne00 / 4); i < ne00; i++) {
470
+ sum[0] += x_scalar[i];
471
+ }
472
  sum[0] /= ne00;
473
  }
474
 
 
483
  y[i00] = x[i00] * scale;
484
  }
485
  if (tpitg == 0) {
486
+ for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
487
+ y_scalar[i00] = x_scalar[i00] * scale;
488
+ }
489
  }
490
  }
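The `4 * (ne00 / 4)` loops added above handle rows whose length is not a multiple of 4: the bulk of the row is processed as `float4`, and a scalar tail finishes the last `ne00 % 4` elements. The same pattern on the CPU (plain C++, illustrative only):

```cpp
// Sum a row using 4-wide chunks plus a scalar tail, mirroring the
// "4 * (ne00 / 4)" tail loops added to the norm kernels above.
float sum_row(const float * x, int ne00) {
    float acc[4] = {0, 0, 0, 0};
    const int n4 = 4 * (ne00 / 4);
    for (int i = 0; i < n4; i += 4) {        // vectorizable body
        acc[0] += x[i + 0];
        acc[1] += x[i + 1];
        acc[2] += x[i + 2];
        acc[3] += x[i + 3];
    }
    float sum = acc[0] + acc[1] + acc[2] + acc[3];
    for (int i = n4; i < ne00; ++i) {        // scalar tail for the ne00 % 4 leftovers
        sum += x[i];
    }
    return sum;
}
```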
491
 
 
495
  // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
496
  inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
497
  float d = qb_curr->d;
498
+
499
  float2 acc = 0.f;
500
+
501
  device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
502
+
503
  for (int i = 0; i < 8; i+=2) {
504
  acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
505
  + yl[i + 1] * (qs[i / 2] & 0x0F00);
 
516
  inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
517
  float d = qb_curr->d;
518
  float m = qb_curr->m;
519
+
520
  float2 acc = 0.f;
521
+
522
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
523
+
524
  for (int i = 0; i < 8; i+=2) {
525
  acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
526
  + yl[i + 1] * (qs[i / 2] & 0x0F00);
 
530
  return d * (acc[0] + acc[1]) + sumy * m;
531
  }
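The `yl` pre-scaling mentioned in the comments above avoids bit shifts entirely: a nibble sitting at bit position 4, 8 or 12 of the `uint16_t` is only masked in place, and the matching `y` value has been pre-divided by 16, 256 or 4096 so the product comes out the same; the constant offset of the quants (-8 for q4_0, -16 for the q5_0 variant below) is folded in through `sumy`. A scalar illustration of the masking trick (plain C++, hypothetical values):

```cpp
#include <cassert>
#include <cstdint>

int main() {
    const uint16_t qs = 0xA4C2;       // four packed 4-bit quants: 0x2, 0xC, 0x4, 0xA
    const float    y  = 3.0f;         // one src1 value, arbitrary

    // Conventional: shift the nibble down, then multiply.
    const float a = float((qs >> 8) & 0xF) * y;

    // Trick used in block_q_n_dot_y: mask in place, multiply by a pre-scaled y.
    const float b = float(qs & 0x0F00) * (y / 256.0f);

    assert(a == b);  // both are 4 * 3 = 12, exactly (powers of two)
    return 0;
}
```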
532
 
533
+ // function to calculate the inner product between half a q5_0 block and 16 floats (yl); sumy is SUM(yl[i])
534
+ // il indicates where the q5 quants begin (0 or QK5_0/4)
535
+ // we assume that the yl's have been multiplied with the appropriate scale factor
536
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
537
+ inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
538
+ float d = qb_curr->d;
539
+
540
+ float2 acc = 0.f;
541
+
542
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
543
+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
544
+
545
+ for (int i = 0; i < 8; i+=2) {
546
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
547
+ + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
548
+ acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
549
+ + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
550
+ }
551
+ return d * (sumy * -16.f + acc[0] + acc[1]);
552
+ }
553
+
554
+ // function to calculate the inner product between half a q5_1 block and 16 floats (yl); sumy is SUM(yl[i])
555
+ // il indicates where the q5 quants begin (0 or QK5_1/4)
556
+ // we assume that the yl's have been multiplied with the appropriate scale factor
557
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
558
+ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
559
+ float d = qb_curr->d;
560
+ float m = qb_curr->m;
561
+
562
+ float2 acc = 0.f;
563
+
564
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
565
+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
566
+
567
+ for (int i = 0; i < 8; i+=2) {
568
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
569
+ + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
570
+ acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
571
+ + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
572
+ }
573
+ return d * (acc[0] + acc[1]) + sumy * m;
574
+ }
575
+
576
  // putting them in the kernel cause a significant performance penalty
577
+ #define N_DST 4 // each SIMD group works on 4 rows
578
+ #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
579
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
580
  //Note: This is a template, but strictly speaking it only applies to
581
  // quantizations where the block size is 32. It also does not
 
586
  int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
587
  uint3 tgpig, uint tiisg, uint sgitg) {
588
  const int nb = ne00/QK4_0;
589
+
590
  const int r0 = tgpig.x;
591
  const int r1 = tgpig.y;
592
  const int im = tgpig.z;
593
+
594
  const int first_row = (r0 * nsg + sgitg) * nr;
595
+
596
  const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
597
+
598
  device const block_q_type * x = (device const block_q_type *) src0 + offset0;
599
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
 
 
600
 
601
+ float yl[16]; // src1 vector cache
602
+ float sumf[nr] = {0.f};
603
+
604
+ const int ix = (tiisg/2);
605
+ const int il = (tiisg%2)*8;
606
 
607
  device const float * yb = y + ix * QK4_0 + il;
608
 
 
613
  sumy += yb[i] + yb[i+1];
614
  yl[i+0] = yb[i+ 0];
615
  yl[i+1] = yb[i+ 1]/256.f;
616
+
617
  sumy += yb[i+16] + yb[i+17];
618
  yl[i+8] = yb[i+16]/16.f;
619
  yl[i+9] = yb[i+17]/4096.f;
 
629
  for (int row = 0; row < nr; ++row) {
630
  const float tot = simd_sum(sumf[row]);
631
  if (tiisg == 0 && first_row + row < ne01) {
632
+ dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
633
  }
634
  }
635
  }
636
 
637
+ kernel void kernel_mul_mv_q4_0_f32(
638
  device const void * src0,
639
  device const float * src1,
640
  device float * dst,
 
647
  constant int64_t & ne1[[buffer(16)]],
648
  constant uint & gqa[[buffer(17)]],
649
  uint3 tgpig[[threadgroup_position_in_grid]],
650
+ uint tiisg[[thread_index_in_simdgroup]],
651
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
652
  mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
653
  }
654
 
655
+ kernel void kernel_mul_mv_q4_1_f32(
656
  device const void * src0,
657
  device const float * src1,
658
  device float * dst,
 
670
  mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
671
  }
672
 
673
+ kernel void kernel_mul_mv_q5_0_f32(
674
+ device const void * src0,
675
+ device const float * src1,
676
+ device float * dst,
677
+ constant int64_t & ne00,
678
+ constant int64_t & ne01[[buffer(4)]],
679
+ constant int64_t & ne02[[buffer(5)]],
680
+ constant int64_t & ne10[[buffer(9)]],
681
+ constant int64_t & ne12[[buffer(11)]],
682
+ constant int64_t & ne0[[buffer(15)]],
683
+ constant int64_t & ne1[[buffer(16)]],
684
+ constant uint & gqa[[buffer(17)]],
685
+ uint3 tgpig[[threadgroup_position_in_grid]],
686
+ uint tiisg[[thread_index_in_simdgroup]],
687
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
688
+ mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
689
+ }
690
+
691
+ kernel void kernel_mul_mv_q5_1_f32(
692
+ device const void * src0,
693
+ device const float * src1,
694
+ device float * dst,
695
+ constant int64_t & ne00,
696
+ constant int64_t & ne01[[buffer(4)]],
697
+ constant int64_t & ne02[[buffer(5)]],
698
+ constant int64_t & ne10[[buffer(9)]],
699
+ constant int64_t & ne12[[buffer(11)]],
700
+ constant int64_t & ne0[[buffer(15)]],
701
+ constant int64_t & ne1[[buffer(16)]],
702
+ constant uint & gqa[[buffer(17)]],
703
+ uint3 tgpig[[threadgroup_position_in_grid]],
704
+ uint tiisg[[thread_index_in_simdgroup]],
705
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
706
+ mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
707
+ }
708
+
709
+
710
  #define NB_Q8_0 8
711
 
712
+ kernel void kernel_mul_mv_q8_0_f32(
713
  device const void * src0,
714
  device const float * src1,
715
  device float * dst,
 
773
 
774
  #define N_F32_F32 4
775
 
776
+ kernel void kernel_mul_mv_f32_f32(
777
  device const char * src0,
778
  device const char * src1,
779
  device float * dst,
 
844
  }
845
  }
846
 
847
+ kernel void kernel_mul_mv_f16_f32_1row(
848
  device const char * src0,
849
  device const char * src1,
850
  device float * dst,
 
863
  constant int64_t & ne0,
864
  constant int64_t & ne1,
865
  uint3 tgpig[[threadgroup_position_in_grid]],
866
+ uint tiisg[[thread_index_in_simdgroup]]) {
867
 
868
  const int64_t r0 = tgpig.x;
869
  const int64_t r1 = tgpig.y;
 
898
 
899
  #define N_F16_F32 4
900
 
901
+ kernel void kernel_mul_mv_f16_f32(
902
  device const char * src0,
903
  device const char * src1,
904
  device float * dst,
 
970
  }
971
 
972
  // Assumes row size (ne00) is a multiple of 4
973
+ kernel void kernel_mul_mv_f16_f32_l4(
974
  device const char * src0,
975
  device const char * src1,
976
  device float * dst,
 
1031
  constant uint64_t & nb1,
1032
  constant uint64_t & nb2,
1033
  constant uint64_t & nb3,
1034
+ constant float & m0,
1035
+ constant float & m1,
1036
+ constant int & n_heads_log2_floor,
1037
  uint3 tgpig[[threadgroup_position_in_grid]],
1038
  uint3 tpitg[[thread_position_in_threadgroup]],
1039
  uint3 ntg[[threads_per_threadgroup]]) {
 
1049
  const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1050
 
1051
  device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1052
+ float m_k;
1053
+ if (i2 < n_heads_log2_floor) {
1054
+ m_k = pow(m0, i2 + 1);
1055
+ } else {
1056
+ m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
1057
+ }
1058
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1059
  device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1060
  dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
1061
  }
1062
  }
1063
 
1064
+ static float rope_yarn_ramp(const float low, const float high, const int i0) {
1065
+ const float y = (i0 / 2 - low) / max(0.001f, high - low);
1066
+ return 1.0f - min(1.0f, max(0.0f, y));
1067
+ }
1068
+
1069
+ // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
1070
+ // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
1071
+ static void rope_yarn(
1072
+ float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
1073
+ thread float * cos_theta, thread float * sin_theta
1074
+ ) {
1075
+ // Get n-d rotational scaling corrected for extrapolation
1076
+ float theta_interp = freq_scale * theta_extrap;
1077
+ float theta = theta_interp;
1078
+ if (ext_factor != 0.0f) {
1079
+ float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
1080
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
1081
+
1082
+ // Get n-d magnitude scaling corrected for interpolation
1083
+ mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
1084
+ }
1085
+ *cos_theta = cos(theta) * mscale;
1086
+ *sin_theta = sin(theta) * mscale;
1087
+ }
1088
+
1089
+ // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
1090
+ // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
1091
+ static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
1092
+ return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
1093
+ }
1094
+
1095
+ static void rope_yarn_corr_dims(
1096
+ int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
1097
+ ) {
1098
+ // start and end correction dims
1099
+ dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
1100
+ dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
1101
+ }
1102
+
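In equation form, the rope_yarn helpers above compute the following (reconstructed directly from the code; s is freq_scale, the 0.001 division guard in rope_yarn_ramp is elided):

```latex
\begin{aligned}
\mathrm{ramp}(i_0) &= 1 - \min\!\Big(1,\ \max\!\big(0,\ \tfrac{i_0/2 - \mathrm{low}}{\mathrm{high} - \mathrm{low}}\big)\Big) \\
\gamma             &= \mathrm{ramp}(i_0)\cdot \mathrm{ext\_factor} \\
\theta             &= (1-\gamma)\, s\,\theta_{\mathrm{extrap}} \;+\; \gamma\,\theta_{\mathrm{extrap}} \\
m_{\mathrm{scale}} &\leftarrow m_{\mathrm{scale}}\,\bigl(1 + 0.1\,\ln(1/s)\bigr) \quad\text{(only when } \mathrm{ext\_factor}\ne 0\text{)} \\
\mathrm{corr\_fac}(n_{\mathrm{rot}}) &= \frac{n_{\mathrm{dims}}\,\ln\!\bigl(n_{\mathrm{orig\_ctx}} / (2\pi\, n_{\mathrm{rot}})\bigr)}{2\,\ln(\mathrm{base})}
\end{aligned}
```

The kernel then uses cos(theta)*m_scale and sin(theta)*m_scale for the usual RoPE pair rotation.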
1103
+ typedef void (rope_t)(
1104
+ device const void * src0,
1105
+ device const int32_t * src1,
1106
+ device float * dst,
1107
+ constant int64_t & ne00,
1108
+ constant int64_t & ne01,
1109
+ constant int64_t & ne02,
1110
+ constant int64_t & ne03,
1111
+ constant uint64_t & nb00,
1112
+ constant uint64_t & nb01,
1113
+ constant uint64_t & nb02,
1114
+ constant uint64_t & nb03,
1115
+ constant int64_t & ne0,
1116
+ constant int64_t & ne1,
1117
+ constant int64_t & ne2,
1118
+ constant int64_t & ne3,
1119
+ constant uint64_t & nb0,
1120
+ constant uint64_t & nb1,
1121
+ constant uint64_t & nb2,
1122
+ constant uint64_t & nb3,
1123
+ constant int & n_past,
1124
+ constant int & n_dims,
1125
+ constant int & mode,
1126
+ constant int & n_orig_ctx,
1127
+ constant float & freq_base,
1128
+ constant float & freq_scale,
1129
+ constant float & ext_factor,
1130
+ constant float & attn_factor,
1131
+ constant float & beta_fast,
1132
+ constant float & beta_slow,
1133
+ uint tiitg[[thread_index_in_threadgroup]],
1134
+ uint3 tptg[[threads_per_threadgroup]],
1135
+ uint3 tgpig[[threadgroup_position_in_grid]]);
1136
+
1137
+ template<typename T>
1138
  kernel void kernel_rope(
1139
+ device const void * src0,
1140
+ device const int32_t * src1,
1141
+ device float * dst,
1142
+ constant int64_t & ne00,
1143
+ constant int64_t & ne01,
1144
+ constant int64_t & ne02,
1145
+ constant int64_t & ne03,
1146
+ constant uint64_t & nb00,
1147
+ constant uint64_t & nb01,
1148
+ constant uint64_t & nb02,
1149
+ constant uint64_t & nb03,
1150
+ constant int64_t & ne0,
1151
+ constant int64_t & ne1,
1152
+ constant int64_t & ne2,
1153
+ constant int64_t & ne3,
1154
+ constant uint64_t & nb0,
1155
+ constant uint64_t & nb1,
1156
+ constant uint64_t & nb2,
1157
+ constant uint64_t & nb3,
1158
+ constant int & n_past,
1159
+ constant int & n_dims,
1160
+ constant int & mode,
1161
+ constant int & n_orig_ctx,
1162
+ constant float & freq_base,
1163
+ constant float & freq_scale,
1164
+ constant float & ext_factor,
1165
+ constant float & attn_factor,
1166
+ constant float & beta_fast,
1167
+ constant float & beta_slow,
1168
  uint tiitg[[thread_index_in_threadgroup]],
1169
  uint3 tptg[[threads_per_threadgroup]],
1170
  uint3 tgpig[[threadgroup_position_in_grid]]) {
 
1174
 
1175
  const bool is_neox = mode & 2;
1176
 
1177
+ float corr_dims[2];
1178
+ rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
1179
 
1180
+ device const int32_t * pos = src1;
1181
+
1182
+ const int64_t p = pos[i2];
1183
+
1184
+ const float theta_0 = (float)p;
1185
  const float inv_ndims = -1.f/n_dims;
1186
 
1187
  if (!is_neox) {
1188
  for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1189
 
1190
  const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
1191
+ float cos_theta, sin_theta;
1192
+ rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1193
 
1194
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1195
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1196
 
1197
+ const T x0 = src[0];
1198
+ const T x1 = src[1];
1199
 
1200
  dst_data[0] = x0*cos_theta - x1*sin_theta;
1201
  dst_data[1] = x0*sin_theta + x1*cos_theta;
 
1204
  for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
1205
  for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
1206
 
1207
+ // simplified from `(ib * n_dims + ic) * inv_ndims`
1208
+ const float cur_rot = inv_ndims*ic - ib;
1209
+
1210
+ const float theta = theta_0 * pow(freq_base, cur_rot);
1211
+ float cos_theta, sin_theta;
1212
+ rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1213
 
1214
  const int64_t i0 = ib*n_dims + ic/2;
1215
 
1216
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1217
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1218
 
1219
  const float x0 = src[0];
1220
  const float x1 = src[n_dims/2];
 
1226
  }
1227
  }
1228
 
1229
+ template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
1230
+ template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
1231
+
1232
  kernel void kernel_cpy_f16_f16(
1233
  device const half * src0,
1234
  device half * dst,
 
1354
  }
1355
  }
1356
 
1357
+ kernel void kernel_concat(
1358
+ device const char * src0,
1359
+ device const char * src1,
1360
+ device char * dst,
1361
+ constant int64_t & ne00,
1362
+ constant int64_t & ne01,
1363
+ constant int64_t & ne02,
1364
+ constant int64_t & ne03,
1365
+ constant uint64_t & nb00,
1366
+ constant uint64_t & nb01,
1367
+ constant uint64_t & nb02,
1368
+ constant uint64_t & nb03,
1369
+ constant int64_t & ne10,
1370
+ constant int64_t & ne11,
1371
+ constant int64_t & ne12,
1372
+ constant int64_t & ne13,
1373
+ constant uint64_t & nb10,
1374
+ constant uint64_t & nb11,
1375
+ constant uint64_t & nb12,
1376
+ constant uint64_t & nb13,
1377
+ constant int64_t & ne0,
1378
+ constant int64_t & ne1,
1379
+ constant int64_t & ne2,
1380
+ constant int64_t & ne3,
1381
+ constant uint64_t & nb0,
1382
+ constant uint64_t & nb1,
1383
+ constant uint64_t & nb2,
1384
+ constant uint64_t & nb3,
1385
+ uint3 tgpig[[threadgroup_position_in_grid]],
1386
+ uint3 tpitg[[thread_position_in_threadgroup]],
1387
+ uint3 ntg[[threads_per_threadgroup]]) {
1388
+
1389
+ const int64_t i03 = tgpig.z;
1390
+ const int64_t i02 = tgpig.y;
1391
+ const int64_t i01 = tgpig.x;
1392
+
1393
+ const int64_t i13 = i03 % ne13;
1394
+ const int64_t i12 = i02 % ne12;
1395
+ const int64_t i11 = i01 % ne11;
1396
+
1397
+ device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
1398
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
1399
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
1400
+
1401
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1402
+ if (i02 < ne02) {
1403
+ ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
1404
+ src0_ptr += ntg.x*nb00;
1405
+ } else {
1406
+ ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
1407
+ src1_ptr += ntg.x*nb10;
1408
+ }
1409
+ dst_ptr += ntg.x*nb0;
1410
+ }
1411
+ }
1412
+
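kernel_concat above stitches two tensors together along dimension 2: planes with i02 < ne02 are copied from src0, the rest from src1. A contiguous CPU equivalent (plain C++, float only, illustrative sketch):

```cpp
#include <cstdint>
#include <cstring>

// Concatenate two contiguous float tensors along dim 2, as kernel_concat does.
// dst must have ne2 = ne02 + ne12; all other dims are assumed equal.
void concat_dim2(const float * src0, const float * src1, float * dst,
                 int64_t ne0, int64_t ne1, int64_t ne02, int64_t ne12, int64_t ne3) {
    const int64_t ne2 = ne02 + ne12;
    for (int64_t i3 = 0; i3 < ne3; ++i3) {
        for (int64_t i2 = 0; i2 < ne2; ++i2) {
            const float * src = (i2 < ne02)
                ? src0 + (i3*ne02 +  i2        )*ne1*ne0   // planes from the first tensor
                : src1 + (i3*ne12 + (i2 - ne02))*ne1*ne0;  // planes from the second tensor
            std::memcpy(dst + (i3*ne2 + i2)*ne1*ne0, src, sizeof(float)*ne1*ne0);
        }
    }
}
```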
1413
  //============================================ k-quants ======================================================
1414
 
1415
  #ifndef QK_K
 
1502
 
1503
  //====================================== dot products =========================
1504
 
1505
+ kernel void kernel_mul_mv_q2_K_f32(
1506
  device const void * src0,
1507
  device const float * src1,
1508
  device float * dst,
 
1646
  }
1647
 
1648
  #if QK_K == 256
1649
+ kernel void kernel_mul_mv_q3_K_f32(
1650
  device const void * src0,
1651
  device const float * src1,
1652
  device float * dst,
 
1675
 
1676
  float yl[32];
1677
 
1678
+ //const uint16_t kmask1 = 0x3030;
1679
+ //const uint16_t kmask2 = 0x0f0f;
1680
 
1681
  const int tid = tiisg/4;
1682
  const int ix = tiisg%4;
 
1798
  }
1799
  }
1800
  #else
1801
+ kernel void kernel_mul_mv_q3_K_f32(
1802
  device const void * src0,
1803
  device const float * src1,
1804
  device float * dst,
 
1869
  #endif
1870
 
1871
  #if QK_K == 256
1872
+ kernel void kernel_mul_mv_q4_K_f32(
1873
  device const void * src0,
1874
  device const float * src1,
1875
  device float * dst,
 
1975
  }
1976
  }
1977
  #else
1978
+ kernel void kernel_mul_mv_q4_K_f32(
1979
  device const void * src0,
1980
  device const float * src1,
1981
  device float * dst,
 
2064
  }
2065
  #endif
2066
 
2067
+ kernel void kernel_mul_mv_q5_K_f32(
2068
  device const void * src0,
2069
  device const float * src1,
2070
  device float * dst,
 
2237
 
2238
  }
2239
 
2240
+ kernel void kernel_mul_mv_q6_K_f32(
2241
  device const void * src0,
2242
  device const float * src1,
2243
  device float * dst,
 
2386
  }
2387
  }
2388
 
2389
+ template <typename type4x4>
2390
+ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
2391
+ device const uint16_t * qs = ((device const uint16_t *)xb + 3);
2392
+ const float d = xb->d;
2393
+ const float md = -16.h * xb->d;
2394
+ const ushort mask = il ? 0x00F0 : 0x000F;
2395
+
2396
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
2397
+
2398
+ const int x_mv = il ? 4 : 0;
2399
+
2400
+ const int gh_mv = il ? 12 : 0;
2401
+ const int gh_bk = il ? 0 : 4;
2402
+
2403
+ for (int i = 0; i < 8; i++) {
2404
+ // extract the 5-th bits for x0 and x1
2405
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
2406
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
2407
+
2408
+ // combine the 4-bits from qs with the 5th bit
2409
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
2410
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
2411
+
2412
+ reg[i/2][2*(i%2)+0] = d * x0 + md;
2413
+ reg[i/2][2*(i%2)+1] = d * x1 + md;
2414
+ }
2415
+ }
2416
+
2417
+ template <typename type4x4>
2418
+ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
2419
+ device const uint16_t * qs = ((device const uint16_t *)xb + 4);
2420
+ const float d = xb->d;
2421
+ const float m = xb->m;
2422
+ const ushort mask = il ? 0x00F0 : 0x000F;
2423
+
2424
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
2425
+
2426
+ const int x_mv = il ? 4 : 0;
2427
+
2428
+ const int gh_mv = il ? 12 : 0;
2429
+ const int gh_bk = il ? 0 : 4;
2430
+
2431
+ for (int i = 0; i < 8; i++) {
2432
+ // extract the 5-th bits for x0 and x1
2433
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
2434
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
2435
+
2436
+ // combine the 4-bits from qs with the 5th bit
2437
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
2438
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
2439
+
2440
+ reg[i/2][2*(i%2)+0] = d * x0 + m;
2441
+ reg[i/2][2*(i%2)+1] = d * x1 + m;
2442
+ }
2443
+ }
2444
+
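Per element, dequantize_q5_0 above rebuilds each weight from a low nibble in `qs` plus one bit taken from `qh`, then maps the unsigned 5-bit value 0..31 to d*(q - 16); the `md = -16 * d` term is that recentring. A single-weight illustration (plain C++, arbitrary values):

```cpp
#include <cassert>
#include <cstdint>

int main() {
    const float   d    = 0.25f;  // block scale, arbitrary
    const uint8_t q_lo = 0x7;    // low 4 bits of one quant (from qs)
    const uint8_t q_hi = 1;      // its 5th bit (from qh)

    const int32_t q = (q_hi << 4) | q_lo;   // unsigned 5-bit value in 0..31

    // dequantize_q5_0 computes d*q + md with md = -16*d, i.e. d*(q - 16),
    // recentring the 5-bit range to -16..15.
    const float w = d * q + (-16.0f * d);
    assert(w == d * (q - 16));              // 0.25 * (23 - 16) = 1.75
    return 0;
}
```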
2445
  template <typename type4x4>
2446
  void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
2447
  device const int8_t * qs = ((device const int8_t *)xb->qs);
 
2631
  }
2632
 
2633
  #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
2634
+ #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
2635
  #define BLOCK_SIZE_K 32
2636
  #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
2637
  #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
 
2668
  const uint r0 = tgpig.y;
2669
  const uint r1 = tgpig.x;
2670
  const uint im = tgpig.z;
2671
+
2672
  // if this block is of 64x32 shape or smaller
2673
  short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
2674
  short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
2675
+
2676
  // a thread shouldn't load data outside of the matrix
2677
  short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
2678
  short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
 
2696
  + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
2697
 
2698
  for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
2699
+ // load data and store to threadgroup memory
2700
  half4x4 temp_a;
2701
  dequantize_func(x, il, temp_a);
2702
  threadgroup_barrier(mem_flags::mem_threadgroup);
2703
+
2704
  #pragma unroll(16)
2705
  for (int i = 0; i < 16; i++) {
2706
  *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
2707
+ + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
2708
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
2709
  }
2710
+
2711
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
2712
+
2713
  il = (il + 2 < nl) ? il + 2 : il % 2;
2714
  x = (il < 2) ? x + (2+nl-1)/nl : x;
2715
  y += BLOCK_SIZE_K;
2716
 
2717
  threadgroup_barrier(mem_flags::mem_threadgroup);
2718
+
2719
+ // load matrices from threadgroup memory and conduct outer products
2720
  threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
2721
  threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
2722
+
2723
  #pragma unroll(4)
2724
  for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
2725
  #pragma unroll(4)
 
2734
 
2735
  lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
2736
  lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
2737
+
2738
  #pragma unroll(8)
2739
  for (int i = 0; i < 8; i++){
2740
  simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
 
2743
  }
2744
 
2745
  if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
2746
+ device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
2747
+ + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
2748
  for (int i = 0; i < 8; i++) {
2749
  simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
2750
  }
2751
  } else {
2752
  // block is smaller than 64x32, we should avoid writing data outside of the matrix
2753
  threadgroup_barrier(mem_flags::mem_threadgroup);
2754
+ threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
2755
  + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
2756
  for (int i = 0; i < 8; i++) {
2757
  simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
2758
  }
2759
 
2760
  threadgroup_barrier(mem_flags::mem_threadgroup);
2761
+
2762
+ device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
2763
+ if (sgitg == 0) {
2764
  for (int i = 0; i < n_rows; i++) {
2765
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
2766
  *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
2767
  }
2768
  }
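kernel_mul_mm above is a classic tiled matrix multiply: each threadgroup owns a 64x32 output tile (BLOCK_SIZE_M x BLOCK_SIZE_N), walks over K in steps of BLOCK_SIZE_K = 32, stages dequantized A and B panels in threadgroup memory, and accumulates the result in 8x8 simdgroup matrices. The CPU shape of the same blocking (plain C++, no dequantization, illustrative only):

```cpp
#include <algorithm>
#include <cstring>

// Blocked matmul C[M x N] = A[M x K] * B[K x N] with the same 64/32/32 tiling
// as kernel_mul_mm; the staging buffers play the role of threadgroup memory.
void mul_mm_blocked(const float * A, const float * B, float * C, int M, int N, int K) {
    constexpr int BM = 64, BN = 32, BK = 32;
    std::memset(C, 0, sizeof(float) * M * N);
    for (int i0 = 0; i0 < M; i0 += BM)
    for (int j0 = 0; j0 < N; j0 += BN)
    for (int k0 = 0; k0 < K; k0 += BK) {                // the loop_k walk over ne00
        const int i1 = std::min(i0 + BM, M);
        const int j1 = std::min(j0 + BN, N);
        const int k1 = std::min(k0 + BK, K);
        for (int i = i0; i < i1; ++i)
        for (int j = j0; j < j1; ++j) {
            float acc = C[i*N + j];
            for (int k = k0; k < k1; ++k) {
                acc += A[i*K + k] * B[k*N + j];         // simdgroup_multiply_accumulate's job
            }
            C[i*N + j] = acc;
        }
    }
}
```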
 
2783
  template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
2784
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
2785
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
2786
+ template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
2787
+ template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
2788
  template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
2789
  template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
2790
  template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
 
2813
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
2814
  template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
2815
  template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
2816
+ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
2817
+ template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
2818
  template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
2819
  template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
2820
  template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
ggml-opencl.cpp CHANGED
@@ -19,7 +19,7 @@
19
  #pragma warning(disable: 4244 4267) // possible loss of data
20
  #endif
21
 
22
- #define CL_DMMV_BLOCK_SIZE 32
23
 
24
  #ifndef K_QUANTS_PER_ITERATION
25
  #define K_QUANTS_PER_ITERATION 1
@@ -202,14 +202,14 @@ inline void get_scale_min_k4(int j, const __global uint8_t *q, uint8_t *d, uint8
202
 
203
  __kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __global float *yy)
204
  {
205
- const int i = get_group_id(0);
206
  const int tid = get_local_id(0);
207
  const int n = tid / 32;
208
  const int l = tid - 32 * n;
209
  const int is = 8 * n + l / 16;
210
 
211
  const uint8_t q = x[i].qs[32 * n + l];
212
- __global float *y = yy + i * QK_K + 128 * n;
213
 
214
  const float dall = vload_half(0, &x[i].d);
215
  const float dmin = vload_half(0, &x[i].dmin);
@@ -223,7 +223,7 @@ __kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __globa
223
  __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __global float *yy)
224
  {
225
  int r = get_local_id(0) / 4;
226
- int i = get_group_id(0);
227
  int tid = r / 2;
228
  int is0 = r % 2;
229
  int l0 = 16 * is0 + 4 * (get_local_id(0) % 4);
@@ -241,7 +241,7 @@ __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __globa
241
  float d_all = vload_half(0, &x[i].d);
242
  float dl = d_all * (us - 32);
243
 
244
- __global float *y = yy + i * QK_K + 128 * n + 32 * j;
245
  const __global uint8_t *q = x[i].qs + 32 * n;
246
  const __global uint8_t *hm = x[i].hmask;
247
 
@@ -251,14 +251,14 @@ __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __globa
251
 
252
  __kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __global float *yy)
253
  {
254
- const int i = get_group_id(0);
255
  const int tid = get_local_id(0);
256
  const int il = tid / 8;
257
  const int ir = tid % 8;
258
  const int is = 2 * il;
259
  const int n = 4;
260
 
261
- __global float *y = yy + i * QK_K + 64 * il + n * ir;
262
 
263
  const float dall = vload_half(0, &x[i].d);
264
  const float dmin = vload_half(0, &x[i].dmin);
@@ -281,13 +281,13 @@ __kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __globa
281
 
282
  __kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __global float *yy)
283
  {
284
- const int i = get_group_id(0);
285
  const int tid = get_local_id(0);
286
  const int il = tid / 16;
287
  const int ir = tid % 16;
288
  const int is = 2 * il;
289
 
290
- __global float *y = yy + i * QK_K + 64 * il + 2 * ir;
291
 
292
  const float dall = vload_half(0, &x[i].d);
293
  const float dmin = vload_half(0, &x[i].dmin);
@@ -313,13 +313,13 @@ __kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __globa
313
 
314
  __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __global float *yy)
315
  {
316
- const int i = get_group_id(0);
317
  const int tid = get_local_id(0);
318
  const int ip = tid / 32;
319
  const int il = tid - 32 * ip;
320
  const int is = 8 * ip + il / 16;
321
 
322
- __global float *y = yy + i * QK_K + 128 * ip + il;
323
 
324
  const float d = vload_half(0, &x[i].d);
325
 
@@ -338,7 +338,7 @@ __kernel void dequantize_mul_mat_vec_q2_K(__global const struct block_q2_K * xx,
338
  const int row = get_group_id(0);
339
 
340
  const int num_blocks_per_row = ncols / QK_K;
341
- const int ib0 = row*num_blocks_per_row;
342
 
343
  __global const struct block_q2_K * x = xx + ib0;
344
 
@@ -413,7 +413,7 @@ __kernel void dequantize_mul_mat_vec_q3_K(__global const struct block_q3_K * xx,
413
  const int row = get_group_id(0);
414
 
415
  const int num_blocks_per_row = ncols / QK_K;
416
- const int ib0 = row*num_blocks_per_row;
417
 
418
  __global const struct block_q3_K * x = xx + ib0;
419
 
@@ -489,7 +489,7 @@ __kernel void dequantize_mul_mat_vec_q4_K(__global const struct block_q4_K * xx,
489
 
490
  const int row = get_group_id(0);
491
  const int num_blocks_per_row = ncols / QK_K;
492
- const int ib0 = row*num_blocks_per_row;
493
 
494
  const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...15
495
  const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION;
@@ -562,7 +562,7 @@ __kernel void dequantize_mul_mat_vec_q5_K(__global const struct block_q5_K * xx,
562
 
563
  const int row = get_group_id(0);
564
  const int num_blocks_per_row = ncols / QK_K;
565
- const int ib0 = row*num_blocks_per_row;
566
 
567
  const int tid = get_local_id(0)/2; // 0...15
568
  const int ix = get_local_id(0)%2;
@@ -641,7 +641,7 @@ __kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx,
641
  const int row = get_group_id(0);
642
 
643
  const int num_blocks_per_row = ncols / QK_K;
644
- const int ib0 = row*num_blocks_per_row;
645
 
646
  __global const struct block_q6_K * x = xx + ib0;
647
 
@@ -730,7 +730,7 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) {
730
  const uint qk = QUANT_K;
731
  const uint qr = QUANT_R;
732
 
733
- const int ib = i/qk; // block index
734
  const int iqs = (i%qk)/qr; // quant index
735
  const int iybs = i - i%qk; // y block start index
736
  const int y_offset = qr == 1 ? 1 : qk/2;
@@ -745,19 +745,21 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) {
745
 
746
  std::string dequant_mul_mat_vec_template = MULTILINE_QUOTE(
747
  __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
748
- const int block_size = get_local_size(0);
749
  const int row = get_group_id(0);
750
  const int tid = get_local_id(0);
751
 
752
  const uint qk = QUANT_K;
753
  const uint qr = QUANT_R;
754
 
 
755
  const int y_offset = qr == 1 ? 1 : qk/2;
756
 
 
 
757
  tmp[tid] = 0;
758
 
759
- for (int i = 0; i < ncols/block_size; i += 2) {
760
- const int col = i*block_size + 2*tid;
761
  const int ib = (row*ncols + col)/qk; // block index
762
  const int iqs = (col%qk)/qr; // quant index
763
  const int iybs = col - col%qk; // y block start index
@@ -773,7 +775,7 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
773
 
774
  // sum up partial sums and write back result
775
  barrier(CLK_LOCAL_MEM_FENCE);
776
- for (int s=block_size/2; s>0; s>>=1) {
777
  if (tid < s) {
778
  tmp[tid] += tmp[tid + s];
779
  }
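The loop shown above (`for (int s = block_size/2; s > 0; s >>= 1)`) is the usual shared-memory tree reduction: each round halves the number of active work-items and folds the upper half of `tmp` onto the lower half. The same pattern on a plain array (C++, power-of-two length assumed):

```cpp
// Tree reduction over a power-of-two sized buffer, matching the
// "s = block_size/2; s > 0; s >>= 1" loop in dequant_mul_mat_vec above.
float reduce_sum(float * tmp, int block_size) {
    for (int s = block_size / 2; s > 0; s >>= 1) {
        for (int tid = 0; tid < s; ++tid) {   // every work-item with tid < s is active
            tmp[tid] += tmp[tid + s];
        }
        // (on the device, a barrier(CLK_LOCAL_MEM_FENCE) separates the rounds)
    }
    return tmp[0];
}
```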
@@ -847,7 +849,7 @@ std::array<std::string, 2> mul_str_values = {
847
  "mul_f32", "float"
848
  };
849
 
850
- std::string& replace(std::string& s, const std::string& from, const std::string& to) {
851
  size_t pos = 0;
852
  while ((pos = s.find(from, pos)) != std::string::npos) {
853
  s.replace(pos, from.length(), to);
@@ -856,7 +858,7 @@ std::string& replace(std::string& s, const std::string& from, const std::string&
856
  return s;
857
  }
858
 
859
- std::string generate_kernels() {
860
  std::stringstream src;
861
  src << program_source << '\n';
862
  src << k_quants_source << '\n';
@@ -1349,30 +1351,42 @@ static cl_int ggml_cl_h2d_tensor_2d(cl_command_queue queue, cl_mem dst, size_t o
1349
  const enum ggml_type type = src->type;
1350
  const size_t ts = ggml_type_size(type);
1351
  const size_t bs = ggml_blck_size(type);
 
1352
 
1353
- const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
1354
- if (nb0 == ts && nb1 == ts*ne0/bs) {
1355
- err = clEnqueueWriteBuffer(queue, dst, CL_FALSE, offset, ne1*nb1, x, 0, NULL, ev);
1356
- return err;
1357
  }
1358
  if (nb0 == ts) {
1359
  const size_t buffer_origin[3] = { offset, 0, 0 };
1360
  const size_t host_origin[3] = { 0, 0, 0 };
1361
- const size_t region[3] = { ts*ne0/bs, ne1, 1 };
1362
- err = clEnqueueWriteBufferRect(queue, dst, CL_FALSE, buffer_origin, host_origin, region, ts*ne0/bs, 0, nb1, 0, x, 0, NULL, ev);
1363
- return err;
1364
  }
 
 
1365
  for (uint64_t i1 = 0; i1 < ne1; i1++) {
1366
  // pretend the row is a matrix with cols=1
1367
- const size_t buffer_origin[3] = { offset, i1, 0 };
1368
  const size_t host_origin[3] = { 0, 0, 0 };
1369
- const size_t region[3] = { ts/bs, ne0, 1 };
1370
- err = clEnqueueWriteBufferRect(queue, dst, CL_FALSE, buffer_origin, host_origin, region, 0, 0, nb0, 0, ((const char *)x) + i1*nb0, 0, NULL, ev);
 
 
 
 
 
1371
  if (err != CL_SUCCESS) {
1372
- break;
 
 
 
1373
  }
1374
  }
1375
- return err;
 
 
 
1376
  }
1377
 
1378
  static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -1381,75 +1395,46 @@ static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1,
1381
  const int64_t ne01 = src0->ne[1];
1382
  const int64_t ne02 = src0->ne[2];
1383
  const int64_t ne03 = src0->ne[3];
1384
- const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
1385
  const int64_t ne10 = src1->ne[0];
1386
  const int64_t ne11 = src1->ne[1];
1387
  const int64_t ne12 = src1->ne[2];
1388
  const int64_t ne13 = src1->ne[3];
1389
- const int64_t nb10 = src1->nb[0];
1390
  const int nb2 = dst->nb[2];
1391
  const int nb3 = dst->nb[3];
1392
  size_t x_size;
1393
  size_t d_size;
1394
 
1395
- cl_mem d_X = ggml_cl_pool_malloc(ne0 * sizeof(float), &x_size); // src0
1396
  cl_mem d_Y = (cl_mem) src1->extra; // src1 is already on device, broadcasted.
1397
- cl_mem d_D = ggml_cl_pool_malloc(ne0 * sizeof(float), &d_size); // dst
1398
 
1399
 
1400
  for (int64_t i03 = 0; i03 < ne03; i03++) {
1401
  for (int64_t i02 = 0; i02 < ne02; i02++) {
1402
- const int i0 = i03*ne02 + i02;
1403
-
1404
  cl_event ev;
1405
 
1406
  // copy src0 to device
1407
- CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, i0, src0, i03, i02, &ev));
1408
-
1409
- if (nb10 == sizeof(float)) {
1410
- // Contiguous, avoid overhead from queueing many kernel runs
1411
- const int64_t i13 = i03%ne13;
1412
- const int64_t i12 = i02%ne12;
1413
- const int i1 = i13*ne12*ne11 + i12*ne11;
1414
-
1415
- cl_int x_offset = 0;
1416
- cl_int y_offset = i1*ne10;
1417
- cl_int d_offset = 0;
1418
-
1419
- size_t global = ne00 * ne01;
1420
- cl_int ky = ne10;
1421
- CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
1422
- CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
1423
- CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
1424
- CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
1425
- CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
1426
- CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
1427
- CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
1428
- CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
1429
- } else {
1430
- for (int64_t i01 = 0; i01 < ne01; i01++) {
1431
- const int64_t i13 = i03%ne13;
1432
- const int64_t i12 = i02%ne12;
1433
- const int64_t i11 = i01%ne11;
1434
- const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
1435
-
1436
- cl_int x_offset = i01*ne00;
1437
- cl_int y_offset = i1*ne10;
1438
- cl_int d_offset = i01*ne00;
1439
 
1440
- // compute
1441
- size_t global = ne00;
1442
- cl_int ky = ne10;
1443
- CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
1444
- CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
1445
- CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
1446
- CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
1447
- CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
1448
- CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
1449
- CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
1450
- CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
1451
- }
1452
- }
 
 
 
 
 
 
1453
 
1454
  CL_CHECK(clReleaseEvent(ev));
1455
  CL_CHECK(clFinish(queue));
@@ -1476,10 +1461,15 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
1476
 
1477
  const int64_t ne10 = src1->ne[0];
1478
  const int64_t ne11 = src1->ne[1];
 
 
1479
 
1480
  const int nb2 = dst->nb[2];
1481
  const int nb3 = dst->nb[3];
1482
 
 
 
 
1483
  const float alpha = 1.0f;
1484
  const float beta = 0.0f;
1485
  const int x_ne = ne01 * ne00;
@@ -1498,35 +1488,46 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
1498
  cl_mem d_Y = ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
1499
  cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
1500
 
 
 
1501
  for (int64_t i03 = 0; i03 < ne03; i03++) {
1502
- for (int64_t i02 = 0; i02 < ne02; i02++) {
1503
- // copy data to device
1504
- if (src0->backend != GGML_BACKEND_GPU) {
1505
- CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
1506
- }
1507
- CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL));
 
 
 
1508
 
1509
- CL_CHECK(clFinish(queue));
 
 
1510
 
1511
- // compute
1512
- cl_event ev_sgemm;
1513
- clblast::StatusCode status = clblast::Gemm<cl_float>(clblast::Layout::kColMajor,
1514
- clblast::Transpose::kYes, clblast::Transpose::kNo,
1515
- ne01, ne11, ne10,
1516
- alpha,
1517
- d_X, 0, ne00,
1518
- d_Y, 0, ne10,
1519
- beta,
1520
- d_D, 0, ne01,
1521
- &queue, &ev_sgemm);
1522
-
1523
- if (status != clblast::StatusCode::kSuccess) {
1524
- GGML_ASSERT(false);
1525
- }
1526
 
1527
- // copy dst to host
1528
- float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
1529
- CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &ev_sgemm, NULL));
 
 
 
 
1530
  }
1531
  }
1532
 
@@ -1537,7 +1538,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
1537
  ggml_cl_pool_free(d_D, d_size);
1538
  }
1539
 
1540
- static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
1541
  GGML_ASSERT(fp16_support);
1542
 
1543
  const int64_t ne00 = src0->ne[0];
@@ -1547,6 +1548,8 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
1547
 
1548
  const int64_t ne10 = src1->ne[0];
1549
  const int64_t ne11 = src1->ne[1];
 
 
1550
 
1551
  const int nb10 = src1->nb[0];
1552
  const int nb11 = src1->nb[1];
@@ -1556,12 +1559,19 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
1556
  const int nb2 = dst->nb[2];
1557
  const int nb3 = dst->nb[3];
1558
 
 
 
 
1559
  const ggml_fp16_t alpha = ggml_fp32_to_fp16(1.0f);
1560
  const ggml_fp16_t beta = ggml_fp32_to_fp16(0.0f);
1561
  const int x_ne = ne01 * ne00;
1562
  const int y_ne = ne11 * ne10;
1563
  const int d_ne = ne11 * ne01;
1564
 
 
 
 
 
1565
  size_t x_size;
1566
  size_t y_size;
1567
  size_t d_size;
@@ -1577,63 +1587,71 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
1577
  bool src1_cont_rows = nb10 == sizeof(float);
1578
  bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
1579
 
1580
- for (int64_t i03 = 0; i03 < ne03; i03++) {
1581
- for (int64_t i02 = 0; i02 < ne02; i02++) {
1582
- // copy src0 to device
1583
- if (src0->backend != GGML_BACKEND_GPU) {
1584
- CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
1585
- }
1586
 
1587
- // convert src1 to fp16
1588
- // TODO: use multiple threads
1589
- ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
1590
- char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
1591
- if (src1_cont_rows) {
1592
- if (src1_cont_cols) {
1593
- ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
 
 
1594
  }
1595
- else {
1596
- for (int64_t i01 = 0; i01 < ne11; i01++) {
1597
- ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
 
 
 
 
1598
  }
1599
- }
1600
- }
1601
- else {
1602
- for (int64_t i01 = 0; i01 < ne11; i01++) {
1603
- for (int64_t i00 = 0; i00 < ne10; i00++) {
1604
- // very slow due to no inlining
1605
- tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
1606
  }
1607
- }
1608
- }
1609
 
1610
- // copy src1 to device
1611
- CL_CHECK(clEnqueueWriteBuffer(queue, d_Y, false, 0, sizeof(ggml_fp16_t) * y_ne, tmp, 0, NULL, NULL));
1612
 
1613
- CL_CHECK(clFinish(queue));
1614
 
1615
- // compute
1616
- cl_event ev_sgemm;
1617
- clblast::StatusCode status = clblast::Gemm<cl_half>(clblast::Layout::kColMajor,
1618
- clblast::Transpose::kYes, clblast::Transpose::kNo,
1619
- ne01, ne11, ne10,
1620
- alpha,
1621
- d_X, 0, ne00,
1622
- d_Y, 0, ne10,
1623
- beta,
1624
- d_D, 0, ne01,
1625
- &queue, &ev_sgemm);
1626
-
1627
- if (status != clblast::StatusCode::kSuccess) {
1628
- GGML_ASSERT(false);
1629
- }
1630
 
1631
- // copy dst to host, then convert to float
1632
- CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(ggml_fp16_t) * d_ne, tmp, 1, &ev_sgemm, NULL));
1633
 
1634
- float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
1635
 
1636
- ggml_fp16_to_fp32_row(tmp, d, d_ne);
 
 
1637
  }
1638
  }
1639
 
@@ -1652,18 +1670,24 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
1652
 
1653
  const int64_t ne10 = src1->ne[0];
1654
  const int64_t ne11 = src1->ne[1];
 
 
1655
 
1656
  const int nb2 = dst->nb[2];
1657
  const int nb3 = dst->nb[3];
1658
  const ggml_type type = src0->type;
1659
- const bool mul_mat_vec = ne11 == 1;
 
 
 
1660
 
1661
  const float alpha = 1.0f;
1662
  const float beta = 0.0f;
1663
  const int x_ne = ne01 * ne00;
1664
  const int y_ne = ne11 * ne10;
1665
  const int d_ne = ne11 * ne01;
1666
- const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
 
1667
 
1668
  size_t x_size;
1669
  size_t y_size;
@@ -1685,78 +1709,86 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
1685
  GGML_ASSERT(to_fp32_cl != nullptr);
1686
 
1687
  const size_t global_denom = ggml_cl_global_denom(type);
1688
- const size_t local = ggml_cl_local_size(type);
1689
 
1690
  size_t ev_idx = 0;
1691
  std::vector<cl_event> events;
1692
 
1693
  for (int64_t i03 = 0; i03 < ne03; i03++) {
1694
- for (int64_t i02 = 0; i02 < ne02; i02++) {
1695
- // copy src0 to device if necessary
1696
- if (src0->backend == GGML_BACKEND_CPU) {
1697
- events.emplace_back();
1698
- CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Q, 0, src0, i03, i02, events.data() + ev_idx++));
1699
- } else if (src0->backend == GGML_BACKEND_GPU) {
1700
- d_Q = (cl_mem) src0->extra;
1701
- } else {
1702
- GGML_ASSERT(false);
1703
- }
1704
- if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
1705
- // copy src1 to device
1706
- events.emplace_back();
1707
- CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, events.data() + ev_idx++));
1708
-
1709
- // compute
1710
- const size_t global = ne01 * CL_DMMV_BLOCK_SIZE;
1711
- const size_t local = CL_DMMV_BLOCK_SIZE;
1712
- const cl_int ncols = ne00;
1713
- events.emplace_back();
1714
- CL_CHECK(clSetKernelArg(*dmmv, 0, sizeof(cl_mem), &d_Q));
1715
- CL_CHECK(clSetKernelArg(*dmmv, 1, sizeof(float) * local, NULL));
1716
- CL_CHECK(clSetKernelArg(*dmmv, 2, sizeof(cl_mem), &d_Y));
1717
- CL_CHECK(clSetKernelArg(*dmmv, 3, sizeof(cl_mem), &d_D));
1718
- CL_CHECK(clSetKernelArg(*dmmv, 4, sizeof(cl_int), &ncols));
1719
- CL_CHECK(clEnqueueNDRangeKernel(queue, *dmmv, 1, NULL, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++));
1720
- } else { // general dequantization kernel + CLBlast matrix matrix multiplication
1721
- // convert src0 to fp32 on device
1722
- const size_t global = x_ne / global_denom;
1723
- CL_CHECK(clSetKernelArg(*to_fp32_cl, 0, sizeof(cl_mem), &d_Q));
1724
- CL_CHECK(clSetKernelArg(*to_fp32_cl, 1, sizeof(cl_mem), &d_X));
1725
- CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL));
1726
-
1727
- // copy src1 to device
1728
- CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL));
1729
-
1730
- events.emplace_back();
1731
-
1732
- // wait for conversion
1733
- CL_CHECK(clFinish(queue));
1734
-
1735
- // compute
1736
- clblast::StatusCode status = clblast::Gemm<cl_float>(clblast::Layout::kColMajor,
1737
- clblast::Transpose::kYes, clblast::Transpose::kNo,
1738
- ne01, ne11, ne10,
1739
- alpha,
1740
- d_X, 0, ne00,
1741
- d_Y, 0, ne10,
1742
- beta,
1743
- d_D, 0, ne01,
1744
- &queue, events.data() + ev_idx++);
1745
-
1746
- if (status != clblast::StatusCode::kSuccess) {
1747
  GGML_ASSERT(false);
1748
  }
1749
- }
1750
 
1751
- // copy dst to host
1752
- float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
1753
- CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &events[events.size() - 1], NULL));
1754
- for (auto *event : events) {
1755
- clReleaseEvent(event);
1756
- }
 
 
 
 
1757
 
1758
- ev_idx = 0;
1759
- events.clear();
 
 
 
 
 
 
 
 
 
1760
  }
1761
  }
1762
 
@@ -1788,7 +1820,7 @@ bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
1788
  return false;
1789
  }
1790
 
1791
- bool ggml_cl_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
1792
  // If device doesn't support FP16
1793
  if (!fp16_support) {
1794
  return false;
@@ -1831,8 +1863,8 @@ void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor *
1831
  }
1832
 
1833
  size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
1834
- if (ggml_cl_mul_mat_use_f16(src0, src1, dst)) {
1835
- return ggml_nelements(src1) * sizeof(ggml_fp16_t);
1836
  }
1837
  return 0;
1838
  }
@@ -1844,17 +1876,19 @@ void ggml_cl_transform_tensor(void * data, ggml_tensor * tensor) {
1844
  const int64_t ne3 = tensor->ne[3];
1845
 
1846
  const ggml_type type = tensor->type;
1847
- const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
 
1848
 
1849
  size_t q_size;
1850
  cl_mem dst = ggml_cl_pool_malloc(q_sz, &q_size);
1851
 
1852
  tensor->data = data;
1853
  // copy tensor to device
 
1854
  for (int64_t i3 = 0; i3 < ne3; i3++) {
1855
  for (int64_t i2 = 0; i2 < ne2; i2++) {
1856
- int i = i3*ne2 + i2;
1857
- CL_CHECK(ggml_cl_h2d_tensor_2d(queue, dst, i*ne0*ne1, tensor, i3, i2, NULL));
1858
  }
1859
  }
1860
 
 
19
  #pragma warning(disable: 4244 4267) // possible loss of data
20
  #endif
21
 
22
+ #define CL_DMMV_LOCAL_SIZE 32
23
 
24
  #ifndef K_QUANTS_PER_ITERATION
25
  #define K_QUANTS_PER_ITERATION 1
 
202
 
203
  __kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __global float *yy)
204
  {
205
+ const int i = get_group_id(0) + get_global_offset(0);
206
  const int tid = get_local_id(0);
207
  const int n = tid / 32;
208
  const int l = tid - 32 * n;
209
  const int is = 8 * n + l / 16;
210
 
211
  const uint8_t q = x[i].qs[32 * n + l];
212
+ __global float *y = yy + get_group_id(0) * QK_K + 128 * n;
213
 
214
  const float dall = vload_half(0, &x[i].d);
215
  const float dmin = vload_half(0, &x[i].dmin);
 
223
  __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __global float *yy)
224
  {
225
  int r = get_local_id(0) / 4;
226
+ int i = get_group_id(0) + get_global_offset(0);
227
  int tid = r / 2;
228
  int is0 = r % 2;
229
  int l0 = 16 * is0 + 4 * (get_local_id(0) % 4);
 
241
  float d_all = vload_half(0, &x[i].d);
242
  float dl = d_all * (us - 32);
243
 
244
+ __global float *y = yy + get_group_id(0) * QK_K + 128 * n + 32 * j;
245
  const __global uint8_t *q = x[i].qs + 32 * n;
246
  const __global uint8_t *hm = x[i].hmask;
247
 
 
251
 
252
  __kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __global float *yy)
253
  {
254
+ const int i = get_group_id(0) + get_global_offset(0);
255
  const int tid = get_local_id(0);
256
  const int il = tid / 8;
257
  const int ir = tid % 8;
258
  const int is = 2 * il;
259
  const int n = 4;
260
 
261
+ __global float *y = yy + get_group_id(0) * QK_K + 64 * il + n * ir;
262
 
263
  const float dall = vload_half(0, &x[i].d);
264
  const float dmin = vload_half(0, &x[i].dmin);
 
281
 
282
  __kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __global float *yy)
283
  {
284
+ const int i = get_group_id(0) + get_global_offset(0);
285
  const int tid = get_local_id(0);
286
  const int il = tid / 16;
287
  const int ir = tid % 16;
288
  const int is = 2 * il;
289
 
290
+ __global float *y = yy + get_group_id(0) * QK_K + 64 * il + 2 * ir;
291
 
292
  const float dall = vload_half(0, &x[i].d);
293
  const float dmin = vload_half(0, &x[i].dmin);
 
313
 
314
  __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __global float *yy)
315
  {
316
+ const int i = get_group_id(0) + get_global_offset(0);
317
  const int tid = get_local_id(0);
318
  const int ip = tid / 32;
319
  const int il = tid - 32 * ip;
320
  const int is = 8 * ip + il / 16;
321
 
322
+ __global float *y = yy + get_group_id(0) * QK_K + 128 * ip + il;
323
 
324
  const float d = vload_half(0, &x[i].d);
325
 
 
338
  const int row = get_group_id(0);
339
 
340
  const int num_blocks_per_row = ncols / QK_K;
341
+ const int ib0 = row*num_blocks_per_row + get_global_offset(0);
342
 
343
  __global const struct block_q2_K * x = xx + ib0;
344
 
 
413
  const int row = get_group_id(0);
414
 
415
  const int num_blocks_per_row = ncols / QK_K;
416
+ const int ib0 = row*num_blocks_per_row + get_global_offset(0);
417
 
418
  __global const struct block_q3_K * x = xx + ib0;
419
 
 
489
 
490
  const int row = get_group_id(0);
491
  const int num_blocks_per_row = ncols / QK_K;
492
+ const int ib0 = row*num_blocks_per_row + get_global_offset(0);
493
 
494
  const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...15
495
  const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION;
 
562
 
563
  const int row = get_group_id(0);
564
  const int num_blocks_per_row = ncols / QK_K;
565
+ const int ib0 = row*num_blocks_per_row + get_global_offset(0);
566
 
567
  const int tid = get_local_id(0)/2; // 0...15
568
  const int ix = get_local_id(0)%2;
 
641
  const int row = get_group_id(0);
642
 
643
  const int num_blocks_per_row = ncols / QK_K;
644
+ const int ib0 = row*num_blocks_per_row + get_global_offset(0);
645
 
646
  __global const struct block_q6_K * x = xx + ib0;
647
 
 
730
  const uint qk = QUANT_K;
731
  const uint qr = QUANT_R;
732
 
733
+ const int ib = i/qk + get_global_offset(0); // block index
734
  const int iqs = (i%qk)/qr; // quant index
735
  const int iybs = i - i%qk; // y block start index
736
  const int y_offset = qr == 1 ? 1 : qk/2;
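Note on the hunks above: the dequantization kernels now add get_global_offset(0) to the index of the source block they read, while their output indices stay relative to get_group_id(0). The host can therefore select a 2D slice of a quantized tensor that already lives on the device simply by launching with a global work offset (as ggml_cl_mul_mat_q_f32 does further down with offset = (i03 * ne02 + i02) * x_bps when src0 is GPU-resident) instead of copying the slice first. The following small C program models that index split; the sizes are made up for illustration and it is not OpenCL code.

    #include <stdio.h>

    // Models the index split used by the updated dequantize kernels: the source
    // block index includes the global work offset, the destination index does not,
    // so the output buffer only needs to hold one dequantized 2D slice.
    // blocks_per_slice and slice are assumed values for illustration.
    int main(void) {
        const int qk_k             = 256;  // super-block size (QK_K)
        const int blocks_per_slice = 4;    // x_bps in the host code (assumed)
        const int slice            = 2;    // which 2D slice to dequantize (assumed)

        const int offset = slice * blocks_per_slice;  // passed as the global work offset

        for (int group_id = 0; group_id < blocks_per_slice; group_id++) {
            const int src_block = group_id + offset;  // i = get_group_id(0) + get_global_offset(0)
            const int dst_base  = group_id * qk_k;    // y = yy + get_group_id(0) * QK_K + ...
            printf("group %d: read block %d, write y[%d..%d]\n",
                   group_id, src_block, dst_base, dst_base + qk_k - 1);
        }
        return 0;
    }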
 
745
 
746
  std::string dequant_mul_mat_vec_template = MULTILINE_QUOTE(
747
  __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
748
+ const int local_size = get_local_size(0);
749
  const int row = get_group_id(0);
750
  const int tid = get_local_id(0);
751
 
752
  const uint qk = QUANT_K;
753
  const uint qr = QUANT_R;
754
 
755
+ const int col_step = local_size * 2;
756
  const int y_offset = qr == 1 ? 1 : qk/2;
757
 
758
+ x += get_global_offset(0);
759
+
760
  tmp[tid] = 0;
761
 
762
+ for (int col = tid*2; col < ncols; col += col_step) {
 
763
  const int ib = (row*ncols + col)/qk; // block index
764
  const int iqs = (col%qk)/qr; // quant index
765
  const int iybs = col - col%qk; // y block start index
 
775
 
776
  // sum up partial sums and write back result
777
  barrier(CLK_LOCAL_MEM_FENCE);
778
+ for (int s=local_size/2; s>0; s>>=1) {
779
  if (tid < s) {
780
  tmp[tid] += tmp[tid + s];
781
  }
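The dequant_mul_mat_vec template above now takes its stride from the actual work-group size: each work-item starts at column tid*2, produces two values per step, and advances by col_step = local_size*2, and the per-item partial sums kept in local memory are folded with the usual tree reduction. The CPU-only sketch below reproduces that access pattern and reduction for the simple qr == 1 case, where the two values per step are adjacent columns; it only illustrates the indexing, not the kernel itself, and local_size = 32 is an assumption.

    #include <stdio.h>

    // CPU sketch of the work distribution in the updated kernel: "work-item" tid
    // handles columns tid*2, tid*2 + col_step, ... (two adjacent values per step,
    // i.e. the qr == 1 case), and the partial sums are combined with the same
    // `for (int s = local_size/2; s > 0; s >>= 1)` tree reduction.
    int main(void) {
        enum { local_size = 32, ncols = 256 };
        const int col_step = local_size * 2;

        float x[ncols], y[ncols], tmp[local_size];
        for (int i = 0; i < ncols; i++) { x[i] = 1.0f; y[i] = 0.5f; }

        // accumulation phase: one local-memory slot per "work-item"
        for (int tid = 0; tid < local_size; tid++) {
            tmp[tid] = 0.0f;
            for (int col = tid*2; col < ncols; col += col_step) {
                tmp[tid] += x[col]*y[col] + x[col + 1]*y[col + 1];
            }
        }

        // tree reduction phase (barrier-synchronized in the real kernel)
        for (int s = local_size/2; s > 0; s >>= 1) {
            for (int tid = 0; tid < s; tid++) {
                tmp[tid] += tmp[tid + s];
            }
        }

        printf("dot = %f (expected %f)\n", tmp[0], 0.5 * ncols);
        return 0;
    }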
 
849
  "mul_f32", "float"
850
  };
851
 
852
+ static std::string& replace(std::string& s, const std::string& from, const std::string& to) {
853
  size_t pos = 0;
854
  while ((pos = s.find(from, pos)) != std::string::npos) {
855
  s.replace(pos, from.length(), to);
 
858
  return s;
859
  }
860
 
861
+ static std::string generate_kernels() {
862
  std::stringstream src;
863
  src << program_source << '\n';
864
  src << k_quants_source << '\n';
 
1351
  const enum ggml_type type = src->type;
1352
  const size_t ts = ggml_type_size(type);
1353
  const size_t bs = ggml_blck_size(type);
1354
+ const uint64_t row_size = ts*ne0/bs;
1355
 
1356
+ const char * x = (const char *) src->data + i2*nb2 + i3*nb3;
1357
+ if (nb0 == ts && nb1 == row_size) {
1358
+ return clEnqueueWriteBuffer(queue, dst, CL_FALSE, offset, ne1*row_size, x, 0, NULL, ev);
 
1359
  }
1360
  if (nb0 == ts) {
1361
  const size_t buffer_origin[3] = { offset, 0, 0 };
1362
  const size_t host_origin[3] = { 0, 0, 0 };
1363
+ const size_t region[3] = { row_size, ne1, 1 };
1364
+ return clEnqueueWriteBufferRect(queue, dst, CL_FALSE, buffer_origin, host_origin, region, row_size, 0, nb1, 0, x, 0, NULL, ev);
 
1365
  }
1366
+ std::vector<cl_event> events;
1367
+ if (ev && ne1>1) events.reserve(ne1-1);
1368
  for (uint64_t i1 = 0; i1 < ne1; i1++) {
1369
  // pretend the row is a matrix with cols=1
1370
+ const size_t buffer_origin[3] = { offset + i1*row_size, 0, 0 };
1371
  const size_t host_origin[3] = { 0, 0, 0 };
1372
+ const size_t region[3] = { ts, ne0/bs, 1 };
1373
+ // if an event is requested, make the last write wait for all previous writes to complete
1374
+ if (ev && i1) {
1375
+ events.push_back(*ev);
1376
+ }
1377
+ cl_uint nevents = i1 == ne1-1 ? events.size() : 0U;
1378
+ err = clEnqueueWriteBufferRect(queue, dst, CL_FALSE, buffer_origin, host_origin, region, ts, 0, nb0, 0, x + i1*nb1, nevents, nevents ? events.data() : nullptr, ev);
1379
  if (err != CL_SUCCESS) {
1380
+ for (auto event : events) {
1381
+ clReleaseEvent(event);
1382
+ }
1383
+ return err;
1384
  }
1385
  }
1386
+ for (auto event : events) {
1387
+ CL_CHECK(clReleaseEvent(event));
1388
+ }
1389
+ return CL_SUCCESS;
1390
  }
1391
 
1392
  static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 
1395
  const int64_t ne01 = src0->ne[1];
1396
  const int64_t ne02 = src0->ne[2];
1397
  const int64_t ne03 = src0->ne[3];
 
1398
  const int64_t ne10 = src1->ne[0];
1399
  const int64_t ne11 = src1->ne[1];
1400
  const int64_t ne12 = src1->ne[2];
1401
  const int64_t ne13 = src1->ne[3];
 
1402
  const int nb2 = dst->nb[2];
1403
  const int nb3 = dst->nb[3];
1404
  size_t x_size;
1405
  size_t d_size;
1406
 
1407
+ cl_mem d_X = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &x_size); // src0
1408
  cl_mem d_Y = (cl_mem) src1->extra; // src1 is already on device, broadcasted.
1409
+ cl_mem d_D = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &d_size); // dst
1410
 
1411
 
1412
  for (int64_t i03 = 0; i03 < ne03; i03++) {
1413
  for (int64_t i02 = 0; i02 < ne02; i02++) {
 
 
1414
  cl_event ev;
1415
 
1416
  // copy src0 to device
1417
+ CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, &ev));
 
 
 
 
 
1418
 
1419
+ const int64_t i13 = i03%ne13;
1420
+ const int64_t i12 = i02%ne12;
1421
+ const int i1 = i13*ne12*ne11 + i12*ne11;
1422
+
1423
+ cl_int x_offset = 0;
1424
+ cl_int y_offset = i1*ne10;
1425
+ cl_int d_offset = 0;
1426
+
1427
+ size_t global = ne00 * ne01;
1428
+ cl_int ky = ne10 * ne11;
1429
+
1430
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
1431
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
1432
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
1433
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
1434
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
1435
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
1436
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
1437
+ CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
1438
 
1439
  CL_CHECK(clReleaseEvent(ev));
1440
  CL_CHECK(clFinish(queue));
 
1461
 
1462
  const int64_t ne10 = src1->ne[0];
1463
  const int64_t ne11 = src1->ne[1];
1464
+ const int64_t ne12 = src1->ne[2];
1465
+ const int64_t ne13 = src1->ne[3];
1466
 
1467
  const int nb2 = dst->nb[2];
1468
  const int nb3 = dst->nb[3];
1469
 
1470
+ const int64_t r2 = ne12 / ne02;
1471
+ const int64_t r3 = ne13 / ne03;
1472
+
1473
  const float alpha = 1.0f;
1474
  const float beta = 0.0f;
1475
  const int x_ne = ne01 * ne00;
 
1488
  cl_mem d_Y = ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
1489
  cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
1490
 
1491
+ size_t x_offset = 0;
1492
+
1493
  for (int64_t i03 = 0; i03 < ne03; i03++) {
1494
+ // TODO: copy src0 here when r3>1
1495
+ for (int64_t i13 = i03 * r3, e13 = i13 + r3; i13 < e13; i13++) {
1496
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
1497
+ if (src0->backend == GGML_BACKEND_GPU) {
1498
+ x_offset = (i03 * ne02 + i02) * x_ne;
1499
+ } else {
1500
+ // copy src0 to device
1501
+ CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
1502
+ }
1503
 
1504
+ for (int64_t i12 = i02 * r2, e12 = i12 + r2; i12 < e12; i12++) {
1505
+ // copy src1 to device
1506
+ CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));
1507
 
1508
+ CL_CHECK(clFinish(queue));
 
 
 
 
 
1509
 
1510
+ // compute
1511
+ cl_event ev_sgemm;
1512
+ clblast::StatusCode status = clblast::Gemm<cl_float>(clblast::Layout::kColMajor,
1513
+ clblast::Transpose::kYes, clblast::Transpose::kNo,
1514
+ ne01, ne11, ne10,
1515
+ alpha,
1516
+ d_X, x_offset, ne00,
1517
+ d_Y, 0, ne10,
1518
+ beta,
1519
+ d_D, 0, ne01,
1520
+ &queue, &ev_sgemm);
1521
+
1522
+ if (status != clblast::StatusCode::kSuccess) {
1523
+ GGML_ASSERT(false);
1524
+ }
1525
+
1526
+ // copy dst to host
1527
+ float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
1528
+ CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &ev_sgemm, NULL));
1529
+ }
1530
+ }
1531
  }
1532
  }
1533
 
 
1538
  ggml_cl_pool_free(d_D, d_size);
1539
  }
1540
 
1541
+ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
1542
  GGML_ASSERT(fp16_support);
1543
 
1544
  const int64_t ne00 = src0->ne[0];
 
1548
 
1549
  const int64_t ne10 = src1->ne[0];
1550
  const int64_t ne11 = src1->ne[1];
1551
+ const int64_t ne12 = src1->ne[2];
1552
+ const int64_t ne13 = src1->ne[3];
1553
 
1554
  const int nb10 = src1->nb[0];
1555
  const int nb11 = src1->nb[1];
 
1559
  const int nb2 = dst->nb[2];
1560
  const int nb3 = dst->nb[3];
1561
 
1562
+ const int64_t r2 = ne12 / ne02;
1563
+ const int64_t r3 = ne13 / ne03;
1564
+
1565
  const ggml_fp16_t alpha = ggml_fp32_to_fp16(1.0f);
1566
  const ggml_fp16_t beta = ggml_fp32_to_fp16(0.0f);
1567
  const int x_ne = ne01 * ne00;
1568
  const int y_ne = ne11 * ne10;
1569
  const int d_ne = ne11 * ne01;
1570
 
1571
+ GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * y_ne);
1572
+ GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * d_ne);
1573
+ ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata;
1574
+
1575
  size_t x_size;
1576
  size_t y_size;
1577
  size_t d_size;
 
1587
  bool src1_cont_rows = nb10 == sizeof(float);
1588
  bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
1589
 
1590
+ size_t x_offset = 0;
 
 
 
 
 
1591
 
1592
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
1593
+ // TODO: copy src0 here when r3>1
1594
+ for (int64_t i13 = i03 * r3, e13 = i13 + r3; i13 < e13; i13++) {
1595
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
1596
+ if (src0->backend == GGML_BACKEND_GPU) {
1597
+ x_offset = (i03 * ne02 + i02) * x_ne;
1598
+ } else {
1599
+ // copy src0 to device
1600
+ CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
1601
  }
1602
+
1603
+ for (int64_t i12 = i02 * r2, e12 = i12 + r2; i12 < e12; i12++) {
1604
+ // convert src1 to fp16
1605
+ // TODO: use multiple threads
1606
+ char * src1i = (char *) src1->data + i13*nb13 + i12*nb12;
1607
+ if (src1_cont_rows) {
1608
+ if (src1_cont_cols) {
1609
+ ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
1610
+ }
1611
+ else {
1612
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
1613
+ ggml_fp32_to_fp16_row((float *) (src1i + i11*nb11), tmp + i11*ne10, ne10);
1614
+ }
1615
+ }
1616
  }
1617
+ else {
1618
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
1619
+ for (int64_t i10 = 0; i10 < ne10; i10++) {
1620
+ // very slow due to no inlining
1621
+ tmp[i11*ne10 + i10] = ggml_fp32_to_fp16(*(float *) (src1i + i11*nb11 + i10*nb10));
1622
+ }
1623
+ }
1624
  }
 
 
1625
 
1626
+ // copy src1 to device
1627
+ CL_CHECK(clEnqueueWriteBuffer(queue, d_Y, false, 0, sizeof(ggml_fp16_t) * y_ne, tmp, 0, NULL, NULL));
1628
 
1629
+ CL_CHECK(clFinish(queue));
1630
 
1631
+ // compute
1632
+ cl_event ev_sgemm;
1633
+ clblast::StatusCode status = clblast::Gemm<cl_half>(clblast::Layout::kColMajor,
1634
+ clblast::Transpose::kYes, clblast::Transpose::kNo,
1635
+ ne01, ne11, ne10,
1636
+ alpha,
1637
+ d_X, x_offset, ne00,
1638
+ d_Y, 0, ne10,
1639
+ beta,
1640
+ d_D, 0, ne01,
1641
+ &queue, &ev_sgemm);
1642
+
1643
+ if (status != clblast::StatusCode::kSuccess) {
1644
+ GGML_ASSERT(false);
1645
+ }
1646
 
1647
+ // copy dst to host, then convert to float
1648
+ CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(ggml_fp16_t) * d_ne, tmp, 1, &ev_sgemm, NULL));
1649
 
1650
+ float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
1651
 
1652
+ ggml_fp16_to_fp32_row(tmp, d, d_ne);
1653
+ }
1654
+ }
1655
  }
1656
  }
1657
 
 
1670
 
1671
  const int64_t ne10 = src1->ne[0];
1672
  const int64_t ne11 = src1->ne[1];
1673
+ const int64_t ne12 = src1->ne[2];
1674
+ const int64_t ne13 = src1->ne[3];
1675
 
1676
  const int nb2 = dst->nb[2];
1677
  const int nb3 = dst->nb[3];
1678
  const ggml_type type = src0->type;
1679
+ const bool mul_mat_vec = ne11 == 1 && ne00%2 == 0;
1680
+
1681
+ const int64_t r2 = ne12 / ne02;
1682
+ const int64_t r3 = ne13 / ne03;
1683
 
1684
  const float alpha = 1.0f;
1685
  const float beta = 0.0f;
1686
  const int x_ne = ne01 * ne00;
1687
  const int y_ne = ne11 * ne10;
1688
  const int d_ne = ne11 * ne01;
1689
+ const int x_bps = x_ne / ggml_blck_size(type); // blocks per 2D slice
1690
+ const size_t q_sz = ggml_type_size(type) * x_bps;
1691
 
1692
  size_t x_size;
1693
  size_t y_size;
 
1709
  GGML_ASSERT(to_fp32_cl != nullptr);
1710
 
1711
  const size_t global_denom = ggml_cl_global_denom(type);
1712
+ const size_t local = mul_mat_vec ? CL_DMMV_LOCAL_SIZE : ggml_cl_local_size(type);
1713
 
1714
  size_t ev_idx = 0;
1715
  std::vector<cl_event> events;
1716
 
1717
  for (int64_t i03 = 0; i03 < ne03; i03++) {
1718
+ // TODO: copy and dequantize src0 here when r3>1
1719
+ for (int64_t i13 = i03 * r3, e13 = i13 + r3; i13 < e13; i13++) {
1720
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
1721
+ // copy src0 to device if necessary
1722
+ if (src0->backend == GGML_BACKEND_CPU) {
1723
+ events.emplace_back();
1724
+ CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Q, 0, src0, i03, i02, events.data() + ev_idx++));
1725
+ } else if (src0->backend == GGML_BACKEND_GPU) {
1726
+ d_Q = (cl_mem) src0->extra;
1727
+ } else {
 
 
 
 
 
1728
  GGML_ASSERT(false);
1729
  }
 
1730
 
1731
+ if (!mul_mat_vec) {
1732
+ // convert src0 to fp32 on device
1733
+ const size_t global = x_ne / global_denom;
1734
+ const size_t offset = src0->backend == GGML_BACKEND_GPU ? (i03 * ne02 + i02) * x_bps : 0;
1735
+ CL_CHECK(clSetKernelArg(*to_fp32_cl, 0, sizeof(cl_mem), &d_Q));
1736
+ CL_CHECK(clSetKernelArg(*to_fp32_cl, 1, sizeof(cl_mem), &d_X));
1737
+ CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, &offset, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL));
1738
+ }
1739
+
1740
+ for (int64_t i12 = i02 * r2, e12 = i12 + r2; i12 < e12; i12++) {
1741
+ if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
1742
+ // copy src1 to device
1743
+ events.emplace_back();
1744
+ CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, events.data() + ev_idx++));
1745
+
1746
+ // compute
1747
+ const size_t global = ne01 * local;
1748
+ const size_t offset = src0->backend == GGML_BACKEND_GPU ? (i03 * ne02 + i02) * x_bps : 0;
1749
+ const cl_int ncols = ne00;
1750
+ events.emplace_back();
1751
+ CL_CHECK(clSetKernelArg(*dmmv, 0, sizeof(cl_mem), &d_Q));
1752
+ CL_CHECK(clSetKernelArg(*dmmv, 1, sizeof(float) * local, NULL));
1753
+ CL_CHECK(clSetKernelArg(*dmmv, 2, sizeof(cl_mem), &d_Y));
1754
+ CL_CHECK(clSetKernelArg(*dmmv, 3, sizeof(cl_mem), &d_D));
1755
+ CL_CHECK(clSetKernelArg(*dmmv, 4, sizeof(cl_int), &ncols));
1756
+ CL_CHECK(clEnqueueNDRangeKernel(queue, *dmmv, 1, &offset, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++));
1757
+ } else { // CLBlast matrix matrix multiplication
1758
+ // copy src1 to device
1759
+ CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));
1760
+
1761
+ // wait for conversion
1762
+ CL_CHECK(clFinish(queue));
1763
+
1764
+ // compute
1765
+ events.emplace_back();
1766
+ clblast::StatusCode status = clblast::Gemm<cl_float>(clblast::Layout::kColMajor,
1767
+ clblast::Transpose::kYes, clblast::Transpose::kNo,
1768
+ ne01, ne11, ne10,
1769
+ alpha,
1770
+ d_X, 0, ne00,
1771
+ d_Y, 0, ne10,
1772
+ beta,
1773
+ d_D, 0, ne01,
1774
+ &queue, events.data() + ev_idx++);
1775
+
1776
+ if (status != clblast::StatusCode::kSuccess) {
1777
+ GGML_ASSERT(false);
1778
+ }
1779
+ }
1780
 
1781
+ // copy dst to host
1782
+ float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
1783
+ CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &events[events.size() - 1], NULL));
1784
+ for (auto *event : events) {
1785
+ clReleaseEvent(event);
1786
+ }
1787
+
1788
+ ev_idx = 0;
1789
+ events.clear();
1790
+ }
1791
+ }
1792
  }
1793
  }
1794
 
 
1820
  return false;
1821
  }
1822
 
1823
+ static bool ggml_cl_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
1824
  // If device doesn't support FP16
1825
  if (!fp16_support) {
1826
  return false;
 
1863
  }
1864
 
1865
  size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
1866
+ if (src0->type == GGML_TYPE_F16 && ggml_cl_mul_mat_use_f16(src0, src1, dst)) {
1867
+ return sizeof(ggml_fp16_t) * std::max(src1->ne[0] * src1->ne[1], dst->ne[0] * dst->ne[1]);
1868
  }
1869
  return 0;
1870
  }
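The restructured mul_mat paths earlier in this file (f32, f16 and the quantized path) now broadcast src0 over the batch dimensions of src1: the ratios r2 = ne12/ne02 and r3 = ne13/ne03 are computed once, each src0 slice (i02, i03) is uploaded or dequantized once per i13 iteration, and the r2 matching src1 slices reuse it; the TODOs note that the copy could also be hoisted when r3 > 1. The f16 path additionally reuses a single slice-sized scratch buffer both for the fp32 to fp16 conversion of src1 and for reading back the fp16 GEMM result, which is why ggml_cl_mul_mat_get_wsize above returns sizeof(ggml_fp16_t) times the larger of the two slice sizes. The loop nesting is the easiest part to get wrong, so here is a minimal standalone C sketch of the index mapping; the shapes are invented, and the batch dimensions are assumed to divide exactly, as the real code requires.

    #include <stdint.h>
    #include <stdio.h>

    // Sketch of the broadcast iteration added to the OpenCL mul_mat paths:
    // each src0 slice (i02, i03) is reused for r2 = ne12/ne02 src1 slices in
    // dim 2 and r3 = ne13/ne03 slices in dim 3. The shapes are assumptions
    // for illustration; the real code reads them from the ggml tensors.
    int main(void) {
        const int64_t ne02 = 2, ne03 = 1;   // src0 batch dims (assumed)
        const int64_t ne12 = 4, ne13 = 3;   // src1 batch dims (assumed)

        const int64_t r2 = ne12 / ne02;
        const int64_t r3 = ne13 / ne03;

        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i13 = i03*r3, e13 = i13 + r3; i13 < e13; i13++) {
                for (int64_t i02 = 0; i02 < ne02; i02++) {
                    // src0 slice (i02, i03) is uploaded/dequantized once here
                    for (int64_t i12 = i02*r2, e12 = i12 + r2; i12 < e12; i12++) {
                        // one GEMM (or dmmv launch) per (i12, i13), writing dst slice (i12, i13)
                        printf("src0[%lld,%lld] x src1[%lld,%lld] -> dst[%lld,%lld]\n",
                               (long long) i02, (long long) i03,
                               (long long) i12, (long long) i13,
                               (long long) i12, (long long) i13);
                    }
                }
            }
        }
        return 0;
    }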
 
1876
  const int64_t ne3 = tensor->ne[3];
1877
 
1878
  const ggml_type type = tensor->type;
1879
+ const size_t s_sz = ggml_type_size(type) * (size_t) (ne0 * ne1 / ggml_blck_size(type));
1880
+ const size_t q_sz = s_sz * (size_t) (ne2 * ne3);
1881
 
1882
  size_t q_size;
1883
  cl_mem dst = ggml_cl_pool_malloc(q_sz, &q_size);
1884
 
1885
  tensor->data = data;
1886
  // copy tensor to device
1887
+ size_t offset = 0;
1888
  for (int64_t i3 = 0; i3 < ne3; i3++) {
1889
  for (int64_t i2 = 0; i2 < ne2; i2++) {
1890
+ CL_CHECK(ggml_cl_h2d_tensor_2d(queue, dst, offset, tensor, i3, i2, NULL));
1891
+ offset += s_sz;
1892
  }
1893
  }
1894
 
ggml-quants.c ADDED
The diff for this file is too large to render. See raw diff
 
ggml-quants.h ADDED
@@ -0,0 +1,224 @@
1
+ #pragma once
2
+
3
+ #include "ggml-impl.h"
4
+
5
+ // GGML internal header
6
+
7
+ #include <stdint.h>
8
+ #include <stddef.h>
9
+
10
+ #define QK4_0 32
11
+ typedef struct {
12
+ ggml_fp16_t d; // delta
13
+ uint8_t qs[QK4_0 / 2]; // nibbles / quants
14
+ } block_q4_0;
15
+ static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
16
+
17
+ #define QK4_1 32
18
+ typedef struct {
19
+ ggml_fp16_t d; // delta
20
+ ggml_fp16_t m; // min
21
+ uint8_t qs[QK4_1 / 2]; // nibbles / quants
22
+ } block_q4_1;
23
+ static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
24
+
25
+ #define QK5_0 32
26
+ typedef struct {
27
+ ggml_fp16_t d; // delta
28
+ uint8_t qh[4]; // 5-th bit of quants
29
+ uint8_t qs[QK5_0 / 2]; // nibbles / quants
30
+ } block_q5_0;
31
+ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
32
+
33
+ #define QK5_1 32
34
+ typedef struct {
35
+ ggml_fp16_t d; // delta
36
+ ggml_fp16_t m; // min
37
+ uint8_t qh[4]; // 5-th bit of quants
38
+ uint8_t qs[QK5_1 / 2]; // nibbles / quants
39
+ } block_q5_1;
40
+ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
41
+
42
+ #define QK8_0 32
43
+ typedef struct {
44
+ ggml_fp16_t d; // delta
45
+ int8_t qs[QK8_0]; // quants
46
+ } block_q8_0;
47
+ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
48
+
49
+ #define QK8_1 32
50
+ typedef struct {
51
+ float d; // delta
52
+ float s; // d * sum(qs[i])
53
+ int8_t qs[QK8_1]; // quants
54
+ } block_q8_1;
55
+ static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
56
+
57
+ //
58
+ // Super-block quantization structures
59
+ //
60
+
61
+ // Super-block size
62
+ #ifdef GGML_QKK_64
63
+ #define QK_K 64
64
+ #define K_SCALE_SIZE 4
65
+ #else
66
+ #define QK_K 256
67
+ #define K_SCALE_SIZE 12
68
+ #endif
69
+
70
+ // 2-bit quantization
71
+ // weight is represented as x = a * q + b
72
+ // 16 blocks of 16 elements each
73
+ // Effectively 2.5625 bits per weight
74
+ typedef struct {
75
+ uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
76
+ uint8_t qs[QK_K/4]; // quants
77
+ ggml_fp16_t d; // super-block scale for quantized scales
78
+ ggml_fp16_t dmin; // super-block scale for quantized mins
79
+ } block_q2_K;
80
+ static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
81
+
82
+ // 3-bit quantization
83
+ // weight is represented as x = a * q
84
+ // 16 blocks of 16 elements each
85
+ // Effectively 3.4375 bits per weight
86
+ #ifdef GGML_QKK_64
87
+ typedef struct {
88
+ uint8_t hmask[QK_K/8]; // quants - high bit
89
+ uint8_t qs[QK_K/4]; // quants - low 2 bits
90
+ uint8_t scales[2];
91
+ ggml_fp16_t d; // super-block scale
92
+ } block_q3_K;
93
+ static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
94
+ #else
95
+ typedef struct {
96
+ uint8_t hmask[QK_K/8]; // quants - high bit
97
+ uint8_t qs[QK_K/4]; // quants - low 2 bits
98
+ uint8_t scales[12]; // scales, quantized with 6 bits
99
+ ggml_fp16_t d; // super-block scale
100
+ } block_q3_K;
101
+ static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
102
+ #endif
103
+
104
+ // 4-bit quantization
105
+ // 8 blocks of 32 elements each
106
+ // weight is represented as x = a * q + b
107
+ // Effectively 4.5 bits per weight
108
+ #ifdef GGML_QKK_64
109
+ typedef struct {
110
+ ggml_fp16_t d[2]; // super-block scales/mins
111
+ uint8_t scales[2]; // 4-bit block scales/mins
112
+ uint8_t qs[QK_K/2]; // 4-bit quants
113
+ } block_q4_K;
114
+ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
115
+ #else
116
+ typedef struct {
117
+ ggml_fp16_t d; // super-block scale for quantized scales
118
+ ggml_fp16_t dmin; // super-block scale for quantized mins
119
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
120
+ uint8_t qs[QK_K/2]; // 4-bit quants
121
+ } block_q4_K;
122
+ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
123
+ #endif
124
+
125
+ // 5-bit quantization
126
+ // 8 blocks of 32 elements each
127
+ // weight is represented as x = a * q + b
128
+ // Effectively 5.5 bits per weight
129
+ #ifdef GGML_QKK_64
130
+ typedef struct {
131
+ ggml_fp16_t d; // super-block scale
132
+ int8_t scales[QK_K/16]; // 8-bit block scales
133
+ uint8_t qh[QK_K/8]; // quants, high bit
134
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
135
+ } block_q5_K;
136
+ static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
137
+ #else
138
+ typedef struct {
139
+ ggml_fp16_t d; // super-block scale for quantized scales
140
+ ggml_fp16_t dmin; // super-block scale for quantized mins
141
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
142
+ uint8_t qh[QK_K/8]; // quants, high bit
143
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
144
+ } block_q5_K;
145
+ static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
146
+ #endif
147
+
148
+ // 6-bit quantization
149
+ // weight is represented as x = a * q
150
+ // 16 blocks of 16 elements each
151
+ // Effectively 6.5625 bits per weight
152
+ typedef struct {
153
+ uint8_t ql[QK_K/2]; // quants, lower 4 bits
154
+ uint8_t qh[QK_K/4]; // quants, upper 2 bits
155
+ int8_t scales[QK_K/16]; // scales, quantized with 8 bits
156
+ ggml_fp16_t d; // super-block scale
157
+ } block_q6_K;
158
+ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
159
+
160
+ // This is only used for intermediate quantization and dot products
161
+ typedef struct {
162
+ float d; // delta
163
+ int8_t qs[QK_K]; // quants
164
+ int16_t bsums[QK_K/16]; // sum of quants in groups of 16
165
+ } block_q8_K;
166
+ static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
167
+
168
+
169
+ // Quantization
170
+ void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
171
+ void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k);
172
+ void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k);
173
+ void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k);
174
+ void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k);
175
+ void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k);
176
+
177
+ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k);
178
+ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k);
179
+ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k);
180
+ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
181
+ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
182
+ void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
183
+
184
+ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
185
+ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
186
+ void quantize_row_q5_0(const float * restrict x, void * restrict y, int k);
187
+ void quantize_row_q5_1(const float * restrict x, void * restrict y, int k);
188
+ void quantize_row_q8_0(const float * restrict x, void * restrict y, int k);
189
+ void quantize_row_q8_1(const float * restrict x, void * restrict y, int k);
190
+
191
+ void quantize_row_q2_K(const float * restrict x, void * restrict y, int k);
192
+ void quantize_row_q3_K(const float * restrict x, void * restrict y, int k);
193
+ void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
194
+ void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
195
+ void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
196
+ void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
197
+
198
+ // Dequantization
199
+ void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
200
+ void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k);
201
+ void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k);
202
+ void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k);
203
+ void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k);
204
+ //void dequantize_row_q8_1(const block_q8_1 * restrict x, float * restrict y, int k);
205
+
206
+ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k);
207
+ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k);
208
+ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k);
209
+ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k);
210
+ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
211
+ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
212
+
213
+ // Dot product
214
+ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
215
+ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
216
+ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
217
+ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
218
+ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
219
+
220
+ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
221
+ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
222
+ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
223
+ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
224
+ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
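The new ggml-quants.h collects the quantization block layouts and the row-level quantize/dequantize/dot-product entry points that used to live inside ggml.c. The "Effectively N bits per weight" figures in the comments come from the block sizes; for example, q6_K packs 128 + 64 + 16 + 2 = 210 bytes per 256 weights with the default QK_K = 256, which is the quoted 6.5625 bits per weight. Below is a minimal round-trip of the q4_0 API declared above, written as a sketch under the assumption that it is compiled as C99 and linked against ggml.c and ggml-quants.c; the ggml_init call is only there so the fp16 conversion tables are set up on builds that need them.

    #include <math.h>
    #include <stdio.h>

    #include "ggml.h"
    #include "ggml-quants.h"

    // Round-trip sketch for the q4_0 helpers declared in ggml-quants.h.
    // k must be a multiple of QK4_0 (32).
    int main(void) {
        // ggml_init also fills the fp16 conversion tables used by the
        // (de)quantization routines on builds without native f16 support
        struct ggml_init_params params = { 1024*1024, NULL, false };
        struct ggml_context * ctx = ggml_init(params);

        enum { k = 64 };
        float src[k], dst[k];
        for (int i = 0; i < k; i++) {
            src[i] = (float) i / k - 0.5f;
        }

        block_q4_0 q[k / QK4_0];        // one block per 32 values

        quantize_row_q4_0(src, q, k);   // fp32 -> 4-bit blocks
        dequantize_row_q4_0(q, dst, k); // 4-bit blocks -> fp32

        float max_err = 0.0f;
        for (int i = 0; i < k; i++) {
            float err = fabsf(src[i] - dst[i]);
            if (err > max_err) {
                max_err = err;
            }
        }
        printf("max q4_0 round-trip error: %f\n", max_err);

        ggml_free(ctx);
        return 0;
    }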
ggml.c CHANGED
The diff for this file is too large to render. See raw diff
 
ggml.h CHANGED
@@ -58,7 +58,8 @@
58
  // {
59
  // ...
60
  //
61
- // struct ggml_cgraph gf = ggml_build_forward(f);
 
62
  //
63
  // // set the input variable and parameter values
64
  // ggml_set_f32(x, 2.0f);
@@ -213,15 +214,14 @@
213
  #define GGML_QNT_VERSION 2 // bump this on quantization format changes
214
  #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
215
 
216
- #define GGML_MAX_DIMS 4
217
- #define GGML_MAX_NODES 4096
218
- #define GGML_MAX_PARAMS 256
219
- #define GGML_MAX_CONTEXTS 64
220
- #define GGML_MAX_SRC 6
221
- #define GGML_MAX_NAME 64
222
- #define GGML_MAX_OP_PARAMS 32
223
- #define GGML_DEFAULT_N_THREADS 4
224
-
225
  #if UINTPTR_MAX == 0xFFFFFFFF
226
  #define GGML_MEM_ALIGN 4
227
  #else
@@ -231,8 +231,9 @@
231
  #define GGML_EXIT_SUCCESS 0
232
  #define GGML_EXIT_ABORTED 1
233
 
234
- #define GGUF_MAGIC 0x46554747 // "GGUF"
235
- #define GGUF_VERSION 2
 
236
 
237
  #define GGUF_DEFAULT_ALIGNMENT 32
238
 
@@ -244,10 +245,21 @@
244
  do { \
245
  if (!(x)) { \
246
  fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
247
- abort(); \
 
 
 
248
  } \
249
  } while (0)
250
 
 
 
 
 
 
 
 
 
251
  // used to copy the number of elements and stride in bytes of tensors into local variables.
252
  // main purpose is to reduce code duplication and improve readability.
253
  //
@@ -318,7 +330,7 @@ extern "C" {
318
  GGML_TYPE_COUNT,
319
  };
320
 
321
- enum ggml_backend {
322
  GGML_BACKEND_CPU = 0,
323
  GGML_BACKEND_GPU = 10,
324
  GGML_BACKEND_GPU_SPLIT = 20,
@@ -392,7 +404,12 @@ extern "C" {
392
  GGML_OP_ALIBI,
393
  GGML_OP_CLAMP,
394
  GGML_OP_CONV_1D,
 
 
 
395
  GGML_OP_CONV_2D,
 
 
396
  GGML_OP_CONV_TRANSPOSE_2D,
397
  GGML_OP_POOL_1D,
398
  GGML_OP_POOL_2D,
@@ -437,6 +454,7 @@ extern "C" {
437
  GGML_UNARY_OP_GELU,
438
  GGML_UNARY_OP_GELU_QUICK,
439
  GGML_UNARY_OP_SILU,
 
440
  };
441
 
442
  enum ggml_object_type {
@@ -445,6 +463,12 @@ extern "C" {
445
  GGML_OBJECT_WORK_BUFFER
446
  };
447
 
 
 
 
 
 
 
448
  // ggml object
449
  struct ggml_object {
450
  size_t offs;
@@ -461,14 +485,16 @@ extern "C" {
461
 
462
  // n-dimensional tensor
463
  struct ggml_tensor {
464
- enum ggml_type type;
465
- enum ggml_backend backend;
 
 
466
 
467
  int n_dims;
468
  int64_t ne[GGML_MAX_DIMS]; // number of elements
469
  size_t nb[GGML_MAX_DIMS]; // stride in bytes:
470
- // nb[0] = sizeof(type)
471
- // nb[1] = nb[0] * ne[0] + padding
472
  // nb[i] = nb[i-1] * ne[i-1]
473
 
474
  // compute data
@@ -496,7 +522,7 @@ extern "C" {
496
 
497
  void * extra; // extra things e.g. for ggml-cuda.cu
498
 
499
- char padding[4];
500
  };
501
 
502
  static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
@@ -509,29 +535,35 @@ extern "C" {
509
 
510
  int n_threads;
511
 
512
- // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
513
- int n_tasks[GGML_MAX_NODES];
514
-
515
  // abort ggml_graph_compute when true
516
  bool (*abort_callback)(void * data);
517
  void * abort_callback_data;
518
  };
519
 
520
- // next prime after GGML_MAX_NODES
521
- // #define GGML_GRAPH_HASHTABLE_SIZE 4099
522
- // next prime after GGML_MAX_NODES * 2 (nodes + leafs)
523
- #define GGML_GRAPH_HASHTABLE_SIZE 8273
 
 
 
 
 
 
524
 
525
  // computation graph
526
  struct ggml_cgraph {
 
527
  int n_nodes;
528
  int n_leafs;
529
 
530
- struct ggml_tensor * nodes[GGML_MAX_NODES];
531
- struct ggml_tensor * grads[GGML_MAX_NODES];
532
- struct ggml_tensor * leafs[GGML_MAX_NODES];
 
 
533
 
534
- void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE];
535
 
536
  // performance
537
  int perf_runs;
@@ -539,8 +571,6 @@ extern "C" {
539
  int64_t perf_time_us;
540
  };
541
 
542
- static const size_t GGML_GRAPH_SIZE = sizeof(struct ggml_cgraph);
543
-
544
  // scratch buffer
545
  struct ggml_scratch {
546
  size_t offs;
@@ -585,6 +615,8 @@ extern "C" {
585
  GGML_API int64_t ggml_cycles(void);
586
  GGML_API int64_t ggml_cycles_per_ms(void);
587
 
 
 
588
  GGML_API void ggml_numa_init(void); // call once for better performance on NUMA systems
589
  GGML_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node
590
 
@@ -674,18 +706,30 @@ extern "C" {
674
  GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
675
  GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src);
676
 
 
 
 
677
  GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
678
 
679
  GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
680
  GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
681
  GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
682
 
 
 
 
683
  GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
684
  GGML_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
685
 
 
 
 
686
  GGML_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
687
  GGML_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
688
 
 
 
 
689
  GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
690
  GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
691
 
@@ -719,6 +763,12 @@ extern "C" {
719
  struct ggml_tensor * a,
720
  struct ggml_tensor * b);
721
 
 
 
 
 
 
 
722
  GGML_API struct ggml_tensor * ggml_add1(
723
  struct ggml_context * ctx,
724
  struct ggml_tensor * a,
@@ -828,6 +878,7 @@ extern "C" {
828
  struct ggml_tensor * a,
829
  struct ggml_tensor * b);
830
 
 
831
  GGML_API struct ggml_tensor * ggml_repeat_back(
832
  struct ggml_context * ctx,
833
  struct ggml_tensor * a,
@@ -892,6 +943,10 @@ extern "C" {
892
  struct ggml_context * ctx,
893
  struct ggml_tensor * a);
894
 
 
 
 
 
895
  GGML_API struct ggml_tensor * ggml_relu_inplace(
896
  struct ggml_context * ctx,
897
  struct ggml_tensor * a);
@@ -970,9 +1025,9 @@ extern "C" {
970
  struct ggml_tensor * b,
971
  float eps);
972
 
973
- // A: n columns, m rows
974
- // B: n columns, p rows (i.e. we transpose it internally)
975
- // result is m columns, p rows
976
  GGML_API struct ggml_tensor * ggml_mul_mat(
977
  struct ggml_context * ctx,
978
  struct ggml_tensor * a,
@@ -1049,7 +1104,6 @@ extern "C" {
1049
  size_t nb1,
1050
  size_t offset);
1051
 
1052
-
1053
  // a -> b, return view(b)
1054
  GGML_API struct ggml_tensor * ggml_cpy(
1055
  struct ggml_context * ctx,
@@ -1072,6 +1126,33 @@ extern "C" {
1072
  struct ggml_context * ctx,
1073
  struct ggml_tensor * a);
1074
 
 
 
 
 
 
 
1075
  // return view(a), b specifies the new shape
1076
  // TODO: when we start computing gradient, make a copy instead of view
1077
  GGML_API struct ggml_tensor * ggml_reshape(
@@ -1219,14 +1300,15 @@ extern "C" {
1219
  struct ggml_tensor * b);
1220
 
1221
  // rotary position embedding
1222
- // if mode & 1 == 1, skip n_past elements
1223
  // if mode & 2 == 1, GPT-NeoX style
1224
  // if mode & 4 == 1, ChatGLM style
1225
- // TODO: avoid creating a new tensor every time
 
1226
  GGML_API struct ggml_tensor * ggml_rope(
1227
  struct ggml_context * ctx,
1228
  struct ggml_tensor * a,
1229
- int n_past,
1230
  int n_dims,
1231
  int mode,
1232
  int n_ctx);
@@ -1235,7 +1317,7 @@ extern "C" {
1235
  GGML_API struct ggml_tensor * ggml_rope_inplace(
1236
  struct ggml_context * ctx,
1237
  struct ggml_tensor * a,
1238
- int n_past,
1239
  int n_dims,
1240
  int mode,
1241
  int n_ctx);
@@ -1244,29 +1326,43 @@ extern "C" {
1244
  GGML_API struct ggml_tensor * ggml_rope_custom(
1245
  struct ggml_context * ctx,
1246
  struct ggml_tensor * a,
1247
- int n_past,
1248
  int n_dims,
1249
  int mode,
1250
  int n_ctx,
 
1251
  float freq_base,
1252
- float freq_scale);
 
 
 
 
1253
 
1254
  // in-place, returns view(a)
1255
  GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
1256
  struct ggml_context * ctx,
1257
  struct ggml_tensor * a,
1258
- int n_past,
1259
  int n_dims,
1260
  int mode,
1261
  int n_ctx,
 
1262
  float freq_base,
1263
- float freq_scale);
 
 
 
 
 
 
 
 
1264
 
1265
  // xPos RoPE, in-place, returns view(a)
1266
  GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
1267
  struct ggml_context * ctx,
1268
  struct ggml_tensor * a,
1269
- int n_past,
1270
  int n_dims,
1271
  float base,
1272
  bool down);
@@ -1276,7 +1372,7 @@ extern "C" {
1276
  GGML_API struct ggml_tensor * ggml_rope_back(
1277
  struct ggml_context * ctx,
1278
  struct ggml_tensor * a,
1279
- int n_past,
1280
  int n_dims,
1281
  int mode,
1282
  int n_ctx,
@@ -1287,7 +1383,7 @@ extern "C" {
1287
 
1288
  // alibi position embedding
1289
  // in-place, returns view(a)
1290
- struct ggml_tensor * ggml_alibi(
1291
  struct ggml_context * ctx,
1292
  struct ggml_tensor * a,
1293
  int n_past,
@@ -1296,7 +1392,7 @@ extern "C" {
1296
 
1297
  // clamp
1298
  // in-place, returns view(a)
1299
- struct ggml_tensor * ggml_clamp(
1300
  struct ggml_context * ctx,
1301
  struct ggml_tensor * a,
1302
  float min,
@@ -1319,6 +1415,14 @@ extern "C" {
1319
  int s,
1320
  int d);
1321
 
 
 
 
 
 
 
 
 
1322
  GGML_API struct ggml_tensor * ggml_conv_2d(
1323
  struct ggml_context * ctx,
1324
  struct ggml_tensor * a,
@@ -1377,6 +1481,8 @@ extern "C" {
1377
  int s0, // stride
1378
  int p0); // padding
1379
 
 
 
1380
  GGML_API struct ggml_tensor * ggml_pool_2d(
1381
  struct ggml_context * ctx,
1382
  struct ggml_tensor * a,
@@ -1385,8 +1491,8 @@ extern "C" {
1385
  int k1,
1386
  int s0,
1387
  int s1,
1388
- int p0,
1389
- int p1);
1390
 
1391
  // nearest interpolate
1392
  // used in stable-diffusion
@@ -1627,19 +1733,22 @@ extern "C" {
1627
  GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
1628
  GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
1629
 
1630
- GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
1631
- GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
1632
-
1633
  // graph allocation in a context
1634
- GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx);
1635
- GGML_API struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor);
 
 
 
 
 
 
1636
  GGML_API size_t ggml_graph_overhead(void);
 
1637
 
1638
  // ggml_graph_plan() has to be called before ggml_graph_compute()
1639
  // when plan.work_size > 0, caller must allocate memory for plan.work_data
1640
  GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
1641
- GGML_API int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
1642
- GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
1643
 
1644
  // same as ggml_graph_compute() but the work data is allocated as a part of the context
1645
  // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
@@ -1647,8 +1756,8 @@ extern "C" {
1647
 
1648
  GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
1649
 
1650
- GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
1651
- GGML_API struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
1652
 
1653
  // print info and performance information for the graph
1654
  GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
@@ -1656,6 +1765,16 @@ extern "C" {
1656
  // dump the graph into a file using the dot format
1657
  GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
1658
 
 
 
 
 
 
 
 
 
 
 
1659
  //
1660
  // optimization
1661
  //
@@ -1682,6 +1801,7 @@ extern "C" {
1682
  GGML_OPT_NO_CONTEXT,
1683
  GGML_OPT_INVALID_WOLFE,
1684
  GGML_OPT_FAIL,
 
1685
 
1686
  GGML_LINESEARCH_FAIL = -128,
1687
  GGML_LINESEARCH_MINIMUM_STEP,
@@ -1690,7 +1810,8 @@ extern "C" {
1690
  GGML_LINESEARCH_INVALID_PARAMETERS,
1691
  };
1692
 
1693
- typedef void (*ggml_opt_callback)(void * data, float * sched);
 
1694
 
1695
  // optimization parameters
1696
  //
@@ -1699,6 +1820,8 @@ extern "C" {
1699
  struct ggml_opt_params {
1700
  enum ggml_opt_type type;
1701
 
 
 
1702
  int n_threads;
1703
 
1704
  // delta-based convergence test
@@ -1721,6 +1844,8 @@ extern "C" {
1721
  bool print_forward_graph;
1722
  bool print_backward_graph;
1723
 
 
 
1724
  // ADAM parameters
1725
  struct {
1726
  int n_iter;
@@ -1766,6 +1891,7 @@ extern "C" {
1766
  float loss_after;
1767
 
1768
  struct {
 
1769
  struct ggml_tensor * m; // first moment
1770
  struct ggml_tensor * v; // second moment
1771
  struct ggml_tensor * pf; // past function values
@@ -1829,12 +1955,19 @@ extern "C" {
1829
  // quantization
1830
  //
1831
 
 
1832
  GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
1833
  GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
1834
  GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
1835
  GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
1836
  GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
1837
 
 
 
 
 
 
 
1838
  GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
1839
 
1840
  //
@@ -1882,26 +2015,26 @@ extern "C" {
1882
 
1883
  GGML_API int gguf_get_n_kv(const struct gguf_context * ctx);
1884
  GGML_API int gguf_find_key(const struct gguf_context * ctx, const char * key);
1885
- GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int i);
1886
-
1887
- GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int i);
1888
- GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i);
1889
-
1890
- // results are undefined if the wrong type is used for the key
1891
- GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int i);
1892
- GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int i);
1893
- GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int i);
1894
- GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int i);
1895
- GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int i);
1896
- GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int i);
1897
- GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int i);
1898
- GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int i);
1899
- GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int i);
1900
- GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int i);
1901
- GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int i);
1902
- GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int i);
1903
- GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int i);
1904
- GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int i);
1905
  GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
1906
 
1907
  GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx);
@@ -2008,7 +2141,7 @@ extern "C" {
2008
  enum ggml_type vec_dot_type;
2009
  } ggml_type_traits_t;
2010
 
2011
- ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
2012
 
2013
  #ifdef __cplusplus
2014
  }
 
58
  // {
59
  // ...
60
  //
61
+ // struct ggml_cgraph * gf = ggml_new_graph(ctx);
62
+ // ggml_build_forward_expand(gf, f);
63
  //
64
  // // set the input variable and parameter values
65
  // ggml_set_f32(x, 2.0f);
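
The updated comment above reflects that ggml_build_forward() is gone and graphs are now allocated inside a context with ggml_new_graph() followed by ggml_build_forward_expand(). A minimal, self-contained sketch of the new idiom (the buffer size and tensor values are illustrative, not taken from this commit):

    #include <stdio.h>
    #include "ggml.h"

    // compute f = x*x + y with the graph API from this sync (illustrative sketch)
    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,  // room for the tensors and the graph itself
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
        struct ggml_tensor * y = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
        struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, x, x), y);

        // graphs are now objects allocated inside the context, not stack values
        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, f);

        ggml_set_f32(x, 2.0f);
        ggml_set_f32(y, 3.0f);

        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

        printf("f = %f\n", ggml_get_f32_1d(f, 0));  // expect 7.0
        ggml_free(ctx);
        return 0;
    }
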
 
214
  #define GGML_QNT_VERSION 2 // bump this on quantization format changes
215
  #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
216
 
217
+ #define GGML_MAX_DIMS 4
218
+ #define GGML_MAX_PARAMS 1024
219
+ #define GGML_MAX_CONTEXTS 64
220
+ #define GGML_MAX_SRC 6
221
+ #define GGML_MAX_NAME 64
222
+ #define GGML_MAX_OP_PARAMS 64
223
+ #define GGML_DEFAULT_N_THREADS 4
224
+ #define GGML_DEFAULT_GRAPH_SIZE 2048
 
225
  #if UINTPTR_MAX == 0xFFFFFFFF
226
  #define GGML_MEM_ALIGN 4
227
  #else
 
231
  #define GGML_EXIT_SUCCESS 0
232
  #define GGML_EXIT_ABORTED 1
233
 
234
+ #define GGUF_MAGIC "GGUF"
235
+
236
+ #define GGUF_VERSION 3
237
 
238
  #define GGUF_DEFAULT_ALIGNMENT 32
239
 
 
245
  do { \
246
  if (!(x)) { \
247
  fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
248
+ fflush(stderr); \
249
+ fflush(stdout); \
250
+ ggml_print_backtrace(); \
251
+ exit(1); \
252
  } \
253
  } while (0)
254
 
255
+ #ifndef NDEBUG
256
+ #define GGML_UNREACHABLE() GGML_ASSERT(!"statement should not be reached")
257
+ #elif defined(__GNUC__)
258
+ #define GGML_UNREACHABLE() __builtin_unreachable()
259
+ #else
260
+ #define GGML_UNREACHABLE() ((void) 0)
261
+ #endif
262
+
263
  // used to copy the number of elements and stride in bytes of tensors into local variables.
264
  // main purpose is to reduce code duplication and improve readability.
265
  //
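
GGML_ASSERT now flushes stdio and prints a backtrace before exiting, and the new GGML_UNREACHABLE() macro marks branches that must never execute: an assert in debug builds, __builtin_unreachable() on GCC/Clang release builds, a no-op otherwise. A tiny hypothetical usage sketch (the helper is not part of this commit):

    #include "ggml.h"

    // hypothetical helper: validate an eval order coming from untrusted input
    static void check_eval_order(int order) {
        GGML_ASSERT(order >= 0 && order < GGML_CGRAPH_EVAL_ORDER_COUNT);
        switch (order) {
            case GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT: /* ... */ break;
            case GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT: /* ... */ break;
            default: GGML_UNREACHABLE(); // excluded by the assert above
        }
    }
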
 
330
  GGML_TYPE_COUNT,
331
  };
332
 
333
+ enum ggml_backend_type {
334
  GGML_BACKEND_CPU = 0,
335
  GGML_BACKEND_GPU = 10,
336
  GGML_BACKEND_GPU_SPLIT = 20,
 
404
  GGML_OP_ALIBI,
405
  GGML_OP_CLAMP,
406
  GGML_OP_CONV_1D,
407
+ GGML_OP_CONV_1D_STAGE_0, // internal
408
+ GGML_OP_CONV_1D_STAGE_1, // internal
409
+ GGML_OP_CONV_TRANSPOSE_1D,
410
  GGML_OP_CONV_2D,
411
+ GGML_OP_CONV_2D_STAGE_0, // internal
412
+ GGML_OP_CONV_2D_STAGE_1, // internal
413
  GGML_OP_CONV_TRANSPOSE_2D,
414
  GGML_OP_POOL_1D,
415
  GGML_OP_POOL_2D,
 
454
  GGML_UNARY_OP_GELU,
455
  GGML_UNARY_OP_GELU_QUICK,
456
  GGML_UNARY_OP_SILU,
457
+ GGML_UNARY_OP_LEAKY
458
  };
459
 
460
  enum ggml_object_type {
 
463
  GGML_OBJECT_WORK_BUFFER
464
  };
465
 
466
+ enum ggml_log_level {
467
+ GGML_LOG_LEVEL_ERROR = 2,
468
+ GGML_LOG_LEVEL_WARN = 3,
469
+ GGML_LOG_LEVEL_INFO = 4
470
+ };
471
+
472
  // ggml object
473
  struct ggml_object {
474
  size_t offs;
 
485
 
486
  // n-dimensional tensor
487
  struct ggml_tensor {
488
+ enum ggml_type type;
489
+ enum ggml_backend_type backend;
490
+
491
+ struct ggml_backend_buffer * buffer;
492
 
493
  int n_dims;
494
  int64_t ne[GGML_MAX_DIMS]; // number of elements
495
  size_t nb[GGML_MAX_DIMS]; // stride in bytes:
496
+ // nb[0] = ggml_type_size(type)
497
+ // nb[1] = nb[0] * (ne[0] / ggml_blck_size(type)) + padding
498
  // nb[i] = nb[i-1] * ne[i-1]
499
 
500
  // compute data
 
522
 
523
  void * extra; // extra things e.g. for ggml-cuda.cu
524
 
525
+ char padding[12];
526
  };
527
 
528
  static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
 
535
 
536
  int n_threads;
537
 
 
 
 
538
  // abort ggml_graph_compute when true
539
  bool (*abort_callback)(void * data);
540
  void * abort_callback_data;
541
  };
542
 
543
+ enum ggml_cgraph_eval_order {
544
+ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,
545
+ GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,
546
+ GGML_CGRAPH_EVAL_ORDER_COUNT
547
+ };
548
+
549
+ struct ggml_hash_set {
550
+ size_t size;
551
+ struct ggml_tensor ** keys;
552
+ };
553
 
554
  // computation graph
555
  struct ggml_cgraph {
556
+ int size;
557
  int n_nodes;
558
  int n_leafs;
559
 
560
+ struct ggml_tensor ** nodes;
561
+ struct ggml_tensor ** grads;
562
+ struct ggml_tensor ** leafs;
563
+
564
+ struct ggml_hash_set visited_hash_table;
565
 
566
+ enum ggml_cgraph_eval_order order;
567
 
568
  // performance
569
  int perf_runs;
 
571
  int64_t perf_time_us;
572
  };
573
 
 
 
574
  // scratch buffer
575
  struct ggml_scratch {
576
  size_t offs;
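
With ggml_cgraph now holding pointer arrays sized at creation time (instead of fixed GGML_MAX_NODES arrays) plus a visited-tensor hash set, a built graph can still be inspected directly. A small sketch that walks the nodes of the new struct definition above; the printing format is illustrative:

    #include <stdio.h>
    #include "ggml.h"

    // sketch: print the ops of a graph created with ggml_new_graph()/ggml_new_graph_custom()
    static void print_graph_ops(struct ggml_cgraph * gf) {
        for (int i = 0; i < gf->n_nodes; ++i) {
            struct ggml_tensor * node = gf->nodes[i];
            printf("node %3d: %-16s %s\n", i, ggml_op_name(node->op), ggml_get_name(node));
        }
    }
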
 
615
  GGML_API int64_t ggml_cycles(void);
616
  GGML_API int64_t ggml_cycles_per_ms(void);
617
 
618
+ GGML_API void ggml_print_backtrace(void);
619
+
620
  GGML_API void ggml_numa_init(void); // call once for better performance on NUMA systems
621
  GGML_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node
622
 
 
706
  GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
707
  GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src);
708
 
709
+ // Context tensor enumeration and lookup
710
+ GGML_API struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx);
711
+ GGML_API struct ggml_tensor * ggml_get_next_tensor (struct ggml_context * ctx, struct ggml_tensor * tensor);
712
  GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
713
 
714
  GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
715
  GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
716
  GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
717
 
718
+ // Converts a flat index into coordinates
719
+ GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
720
+
721
  GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
722
  GGML_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
723
 
724
+ GGML_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
725
+ GGML_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
726
+
727
  GGML_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
728
  GGML_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
729
 
730
+ GGML_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
731
+ GGML_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
732
+
733
  GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
734
  GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
735
 
 
763
  struct ggml_tensor * a,
764
  struct ggml_tensor * b);
765
 
766
+ GGML_API struct ggml_tensor * ggml_add_cast(
767
+ struct ggml_context * ctx,
768
+ struct ggml_tensor * a,
769
+ struct ggml_tensor * b,
770
+ enum ggml_type type);
771
+
772
  GGML_API struct ggml_tensor * ggml_add1(
773
  struct ggml_context * ctx,
774
  struct ggml_tensor * a,
 
878
  struct ggml_tensor * a,
879
  struct ggml_tensor * b);
880
 
881
+ // sums repetitions in a into shape of b
882
  GGML_API struct ggml_tensor * ggml_repeat_back(
883
  struct ggml_context * ctx,
884
  struct ggml_tensor * a,
 
943
  struct ggml_context * ctx,
944
  struct ggml_tensor * a);
945
 
946
+ GGML_API struct ggml_tensor * ggml_leaky(
947
+ struct ggml_context * ctx,
948
+ struct ggml_tensor * a);
949
+
950
  GGML_API struct ggml_tensor * ggml_relu_inplace(
951
  struct ggml_context * ctx,
952
  struct ggml_tensor * a);
 
1025
  struct ggml_tensor * b,
1026
  float eps);
1027
 
1028
+ // A: k columns, n rows => [ne03, ne02, n, k]
1029
+ // B: k columns, m rows (i.e. we transpose it internally) => [ne03 * x, ne02 * y, m, k]
1030
+ // result is n columns, m rows => [ne03 * x, ne02 * y, m, n]
1031
  GGML_API struct ggml_tensor * ggml_mul_mat(
1032
  struct ggml_context * ctx,
1033
  struct ggml_tensor * a,
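
The comment added above documents the ggml_mul_mat() shape convention: ne[] is listed fastest-varying first and B is transposed internally, so the two operands share ne[0]. A minimal sketch with concrete (arbitrary) sizes:

    #include "ggml.h"

    // sketch: A is k x n (k = 64 columns, n = 128 rows), B is k x m (m = 32 rows);
    // the result C has ne = {128, 32, 1, 1}, i.e. n columns and m rows
    static struct ggml_tensor * matmul_example(struct ggml_context * ctx) {
        struct ggml_tensor * A = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 128);
        struct ggml_tensor * B = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64,  32);
        return ggml_mul_mat(ctx, A, B);
    }
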
 
1104
  size_t nb1,
1105
  size_t offset);
1106
 
 
1107
  // a -> b, return view(b)
1108
  GGML_API struct ggml_tensor * ggml_cpy(
1109
  struct ggml_context * ctx,
 
1126
  struct ggml_context * ctx,
1127
  struct ggml_tensor * a);
1128
 
1129
+ // make contiguous, with new shape
1130
+ GGML_API struct ggml_tensor * ggml_cont_1d(
1131
+ struct ggml_context * ctx,
1132
+ struct ggml_tensor * a,
1133
+ int64_t ne0);
1134
+
1135
+ GGML_API struct ggml_tensor * ggml_cont_2d(
1136
+ struct ggml_context * ctx,
1137
+ struct ggml_tensor * a,
1138
+ int64_t ne0,
1139
+ int64_t ne1);
1140
+
1141
+ GGML_API struct ggml_tensor * ggml_cont_3d(
1142
+ struct ggml_context * ctx,
1143
+ struct ggml_tensor * a,
1144
+ int64_t ne0,
1145
+ int64_t ne1,
1146
+ int64_t ne2);
1147
+
1148
+ GGML_API struct ggml_tensor * ggml_cont_4d(
1149
+ struct ggml_context * ctx,
1150
+ struct ggml_tensor * a,
1151
+ int64_t ne0,
1152
+ int64_t ne1,
1153
+ int64_t ne2,
1154
+ int64_t ne3);
1155
+
1156
  // return view(a), b specifies the new shape
1157
  // TODO: when we start computing gradient, make a copy instead of view
1158
  GGML_API struct ggml_tensor * ggml_reshape(
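
The new ggml_cont_1d() .. ggml_cont_4d() helpers combine "make contiguous" with a reshape, a pattern that previously needed an explicit copy into a freshly created tensor. An illustrative sketch of the usual attention-output flattening; the tensor name and layout are assumptions, not taken from this commit:

    #include "ggml.h"

    // sketch: flatten a permuted attention result [d_head, n_head, n_tokens]
    // into a contiguous 2-D tensor [d_head*n_head, n_tokens] with a single op
    static struct ggml_tensor * flatten_heads(struct ggml_context * ctx, struct ggml_tensor * kqv) {
        const int64_t d_head   = kqv->ne[0];
        const int64_t n_head   = kqv->ne[1];
        const int64_t n_tokens = kqv->ne[2];
        return ggml_cont_2d(ctx, kqv, d_head*n_head, n_tokens);
    }
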
 
1300
  struct ggml_tensor * b);
1301
 
1302
  // rotary position embedding
1303
+ // if mode & 1 == 1, skip n_past elements (DEPRECATED)
1304
  // if mode & 2 == 1, GPT-NeoX style
1305
  // if mode & 4 == 1, ChatGLM style
1306
+ //
1307
+ // b is an int32 vector with size a->ne[2], it contains the positions
1308
  GGML_API struct ggml_tensor * ggml_rope(
1309
  struct ggml_context * ctx,
1310
  struct ggml_tensor * a,
1311
+ struct ggml_tensor * b,
1312
  int n_dims,
1313
  int mode,
1314
  int n_ctx);
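
As the updated comment says, the rope operators now take an explicit I32 position tensor b (one entry per token in a->ne[2]) instead of an n_past argument. A hypothetical helper showing how such positions can be built and passed, assuming a context that allocates tensor data (with no_alloc or measure contexts the positions would be filled after allocation):

    #include "ggml.h"

    // hypothetical helper: RoPE with explicit absolute positions n_past .. n_past+n_tokens-1
    static struct ggml_tensor * rope_with_past(struct ggml_context * ctx,
                                               struct ggml_tensor  * cur,
                                               int n_past, int n_dims, int mode, int n_ctx) {
        const int n_tokens = (int) cur->ne[2];
        struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
        for (int i = 0; i < n_tokens; ++i) {
            ggml_set_i32_1d(pos, i, n_past + i); // absolute position of token i
        }
        return ggml_rope(ctx, cur, pos, n_dims, mode, n_ctx);
    }
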
 
1317
  GGML_API struct ggml_tensor * ggml_rope_inplace(
1318
  struct ggml_context * ctx,
1319
  struct ggml_tensor * a,
1320
+ struct ggml_tensor * b,
1321
  int n_dims,
1322
  int mode,
1323
  int n_ctx);
 
1326
  GGML_API struct ggml_tensor * ggml_rope_custom(
1327
  struct ggml_context * ctx,
1328
  struct ggml_tensor * a,
1329
+ struct ggml_tensor * b,
1330
  int n_dims,
1331
  int mode,
1332
  int n_ctx,
1333
+ int n_orig_ctx,
1334
  float freq_base,
1335
+ float freq_scale,
1336
+ float ext_factor,
1337
+ float attn_factor,
1338
+ float beta_fast,
1339
+ float beta_slow);
1340
 
1341
  // in-place, returns view(a)
1342
  GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
1343
  struct ggml_context * ctx,
1344
  struct ggml_tensor * a,
1345
+ struct ggml_tensor * b,
1346
  int n_dims,
1347
  int mode,
1348
  int n_ctx,
1349
+ int n_orig_ctx,
1350
  float freq_base,
1351
+ float freq_scale,
1352
+ float ext_factor,
1353
+ float attn_factor,
1354
+ float beta_fast,
1355
+ float beta_slow);
1356
+
1357
+ // compute correction dims for YaRN RoPE scaling
1358
+ void ggml_rope_yarn_corr_dims(
1359
+ int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
1360
 
1361
  // xPos RoPE, in-place, returns view(a)
1362
  GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
1363
  struct ggml_context * ctx,
1364
  struct ggml_tensor * a,
1365
+ struct ggml_tensor * b,
1366
  int n_dims,
1367
  float base,
1368
  bool down);
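
ggml_rope_custom() and ggml_rope_custom_inplace() gain the YaRN context-extension parameters next to the existing freq_base/freq_scale, and ggml_rope_yarn_corr_dims() exposes the correction-dimension computation. A hedged sketch of a call; the numeric values are commonly used defaults (ext_factor = 0 keeps plain frequency scaling) and are not mandated by this commit:

    #include "ggml.h"

    // hypothetical wrapper around the extended ggml_rope_custom() signature
    static struct ggml_tensor * rope_yarn_sketch(struct ggml_context * ctx,
                                                 struct ggml_tensor  * cur,   // activations
                                                 struct ggml_tensor  * pos,   // I32 positions, size cur->ne[2]
                                                 int n_dims, int n_ctx, int n_orig_ctx) {
        return ggml_rope_custom(ctx, cur, pos, n_dims, /*mode =*/ 0, n_ctx, n_orig_ctx,
                                /*freq_base   =*/ 10000.0f,
                                /*freq_scale  =*/ 1.0f,
                                /*ext_factor  =*/ 0.0f,
                                /*attn_factor =*/ 1.0f,
                                /*beta_fast   =*/ 32.0f,
                                /*beta_slow   =*/ 1.0f);
    }
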
 
1372
  GGML_API struct ggml_tensor * ggml_rope_back(
1373
  struct ggml_context * ctx,
1374
  struct ggml_tensor * a,
1375
+ struct ggml_tensor * b,
1376
  int n_dims,
1377
  int mode,
1378
  int n_ctx,
 
1383
 
1384
  // alibi position embedding
1385
  // in-place, returns view(a)
1386
+ GGML_API struct ggml_tensor * ggml_alibi(
1387
  struct ggml_context * ctx,
1388
  struct ggml_tensor * a,
1389
  int n_past,
 
1392
 
1393
  // clamp
1394
  // in-place, returns view(a)
1395
+ GGML_API struct ggml_tensor * ggml_clamp(
1396
  struct ggml_context * ctx,
1397
  struct ggml_tensor * a,
1398
  float min,
 
1415
  int s,
1416
  int d);
1417
 
1418
+ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
1419
+ struct ggml_context * ctx,
1420
+ struct ggml_tensor * a,
1421
+ struct ggml_tensor * b,
1422
+ int s0,
1423
+ int p0,
1424
+ int d0);
1425
+
1426
  GGML_API struct ggml_tensor * ggml_conv_2d(
1427
  struct ggml_context * ctx,
1428
  struct ggml_tensor * a,
 
1481
  int s0, // stride
1482
  int p0); // padding
1483
 
1484
+ // the result will have 2*p0 padding for the first dimension
1485
+ // and 2*p1 padding for the second dimension
1486
  GGML_API struct ggml_tensor * ggml_pool_2d(
1487
  struct ggml_context * ctx,
1488
  struct ggml_tensor * a,
 
1491
  int k1,
1492
  int s0,
1493
  int s1,
1494
+ float p0,
1495
+ float p1);
1496
 
1497
  // nearest interpolate
1498
  // used in stable-diffusion
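
For ggml_pool_2d() above: the padding parameters are now floats, and the new comment notes that the result carries 2*p0 and 2*p1 padding along the first and second dimension. A small sketch of a call; the kernel and stride values are arbitrary:

    #include "ggml.h"

    // sketch: 2x2 average pooling with stride 2 and no padding
    static struct ggml_tensor * avg_pool_2x2(struct ggml_context * ctx, struct ggml_tensor * img) {
        return ggml_pool_2d(ctx, img, GGML_OP_POOL_AVG,
                            /*k0 =*/ 2, /*k1 =*/ 2,
                            /*s0 =*/ 2, /*s1 =*/ 2,
                            /*p0 =*/ 0.0f, /*p1 =*/ 0.0f);
    }
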
 
1733
  GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
1734
  GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
1735
 
 
 
 
1736
  // graph allocation in a context
1737
+ GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
1738
+ GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads);
1739
+ GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
1740
+ GGML_API struct ggml_cgraph * ggml_graph_view (struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i0, int i1);
1741
+ GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
1742
+ GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads
1743
+ GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
1744
+
1745
  GGML_API size_t ggml_graph_overhead(void);
1746
+ GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);
1747
 
1748
  // ggml_graph_plan() has to be called before ggml_graph_compute()
1749
  // when plan.work_size > 0, caller must allocate memory for plan.work_data
1750
  GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
1751
+ GGML_API int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
 
1752
 
1753
  // same as ggml_graph_compute() but the work data is allocated as a part of the context
1754
  // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
 
1756
 
1757
  GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
1758
 
1759
+ GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
1760
+ GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
1761
 
1762
  // print info and performance information for the graph
1763
  GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
 
1765
  // dump the graph into a file using the dot format
1766
  GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
1767
 
1768
+ // build gradient checkpointing backward graph gb for gf using provided checkpoints
1769
+ // gb_tmp will contain original backward graph with rewritten backward process nodes,
1770
+ // but without the second forward pass nodes.
1771
+ GGML_API void ggml_build_backward_gradient_checkpointing(
1772
+ struct ggml_context * ctx,
1773
+ struct ggml_cgraph * gf,
1774
+ struct ggml_cgraph * gb,
1775
+ struct ggml_cgraph * gb_tmp,
1776
+ struct ggml_tensor * * checkpoints,
1777
+ int n_checkpoints);
1778
  //
1779
  // optimization
1780
  //
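
Summarizing the graph-construction changes above: ggml_build_forward() and ggml_build_forward_ctx() are removed, graphs are always allocated in a context, and ggml_new_graph_custom()/ggml_graph_overhead_custom() let callers pick a node budget larger than GGML_DEFAULT_GRAPH_SIZE. A sketch of sizing a context for such a graph; the arithmetic mirrors what whisper.cpp does further down with WHISPER_MAX_NODES, but the exact sizes are illustrative:

    #include "ggml.h"

    // sketch: a context big enough for up to n_nodes tensors plus the graph metadata,
    // then a graph with an explicit size and no gradients (inference use case)
    static struct ggml_cgraph * new_big_graph(struct ggml_context ** out_ctx, size_t n_nodes) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ ggml_tensor_overhead()*n_nodes + ggml_graph_overhead_custom(n_nodes, false),
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true, // tensor data is assigned later, e.g. by ggml-alloc
        };
        *out_ctx = ggml_init(params);
        return ggml_new_graph_custom(*out_ctx, n_nodes, /*grads =*/ false);
    }
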
 
1801
  GGML_OPT_NO_CONTEXT,
1802
  GGML_OPT_INVALID_WOLFE,
1803
  GGML_OPT_FAIL,
1804
+ GGML_OPT_CANCEL,
1805
 
1806
  GGML_LINESEARCH_FAIL = -128,
1807
  GGML_LINESEARCH_MINIMUM_STEP,
 
1810
  GGML_LINESEARCH_INVALID_PARAMETERS,
1811
  };
1812
 
1813
+ typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
1814
+ typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
1815
 
1816
  // optimization parameters
1817
  //
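
Two callback typedefs change in the hunk above: ggml_opt_callback now reports the gradient-accumulation step and lets the callback request cancellation (see GGML_OPT_CANCEL), and ggml_log_callback is new. A minimal sketch of a log callback that a backend can be pointed at; the FILE* plumbing through user_data is an illustrative choice:

    #include <stdio.h>
    #include "ggml.h"

    // sketch: route ggml log output to a FILE* passed via user_data (stderr if NULL)
    static void my_log_callback(enum ggml_log_level level, const char * text, void * user_data) {
        FILE * out = user_data ? (FILE *) user_data : stderr;
        const char * tag = level == GGML_LOG_LEVEL_ERROR ? "E"
                         : level == GGML_LOG_LEVEL_WARN  ? "W"
                         :                                 "I";
        fprintf(out, "[%s] %s", tag, text);
    }
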
 
1820
  struct ggml_opt_params {
1821
  enum ggml_opt_type type;
1822
 
1823
+ size_t graph_size;
1824
+
1825
  int n_threads;
1826
 
1827
  // delta-based convergence test
 
1844
  bool print_forward_graph;
1845
  bool print_backward_graph;
1846
 
1847
+ int n_gradient_accumulation;
1848
+
1849
  // ADAM parameters
1850
  struct {
1851
  int n_iter;
 
1891
  float loss_after;
1892
 
1893
  struct {
1894
+ struct ggml_tensor * g; // current gradient
1895
  struct ggml_tensor * m; // first moment
1896
  struct ggml_tensor * v; // second moment
1897
  struct ggml_tensor * pf; // past function values
 
1955
  // quantization
1956
  //
1957
 
1958
+ // TODO: these would probably get removed in favor of the more general ggml_quantize_chunk
1959
  GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
1960
  GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
1961
  GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
1962
  GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
1963
  GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
1964
 
1965
+ GGML_API size_t ggml_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);
1966
+ GGML_API size_t ggml_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);
1967
+ GGML_API size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
1968
+ GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
1969
+ GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
1970
+
1971
  GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
1972
 
1973
  //
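
The k-quant quantizers (Q2_K .. Q6_K) are now declared next to the older block formats, and ggml_quantize_chunk() dispatches on the requested type. A hedged sketch of quantizing a flat F32 buffer; the destination-size bound and the 16-entry histogram follow the convention used by callers of this API, but treat them as assumptions:

    #include <stdlib.h>
    #include "ggml.h"

    // sketch: quantize n floats to Q4_K (n should be a multiple of the 256-wide K-quant block)
    static void * quantize_f32_to_q4_K(const float * src, int n, size_t * out_size) {
        int64_t hist[16] = {0}; // optional per-bucket histogram filled by the quantizer
        void  * dst = malloc((size_t)(ggml_type_sizef(GGML_TYPE_Q4_K)*n) + 1024); // rough upper bound
        *out_size = ggml_quantize_chunk(GGML_TYPE_Q4_K, src, dst, /*start =*/ 0, /*n =*/ n, hist);
        return dst;
    }
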
 
2015
 
2016
  GGML_API int gguf_get_n_kv(const struct gguf_context * ctx);
2017
  GGML_API int gguf_find_key(const struct gguf_context * ctx, const char * key);
2018
+ GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int key_id);
2019
+
2020
+ GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int key_id);
2021
+ GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id);
2022
+
2023
+ // will abort if the wrong type is used for the key
2024
+ GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int key_id);
2025
+ GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int key_id);
2026
+ GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int key_id);
2027
+ GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int key_id);
2028
+ GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int key_id);
2029
+ GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int key_id);
2030
+ GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int key_id);
2031
+ GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int key_id);
2032
+ GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int key_id);
2033
+ GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
2034
+ GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
2035
+ GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
2036
+ GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id);
2037
+ GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
2038
  GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
2039
 
2040
  GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx);
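
All gguf_get_val_* accessors now take a key_id (as returned by gguf_find_key()), and the contract changes from "results are undefined" to "will abort" on a type mismatch, so callers should check the type before reading. A small sketch:

    #include <stdint.h>
    #include "ggml.h"

    // sketch: read an optional u32 key, falling back if it is missing or has the wrong type
    static uint32_t gguf_read_u32(const struct gguf_context * gctx, const char * key, uint32_t fallback) {
        const int key_id = gguf_find_key(gctx, key);
        if (key_id < 0 || gguf_get_kv_type(gctx, key_id) != GGUF_TYPE_UINT32) {
            return fallback;
        }
        return gguf_get_val_u32(gctx, key_id);
    }
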
 
2141
  enum ggml_type vec_dot_type;
2142
  } ggml_type_traits_t;
2143
 
2144
+ GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
2145
 
2146
  #ifdef __cplusplus
2147
  }
whisper.cpp CHANGED
@@ -120,6 +120,7 @@ static void byteswap_tensor(ggml_tensor * tensor) {
120
  //#define WHISPER_USE_FLASH_ATTN
121
  //#define WHISPER_USE_FLASH_FF
122
  #define WHISPER_MAX_DECODERS 16
 
123
 
124
  //
125
  // ggml helpers
@@ -663,7 +664,7 @@ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::funct
663
  auto & meta = allocr.meta;
664
  auto & data = allocr.data;
665
 
666
- meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
667
 
668
  alloc = ggml_allocr_new_measure(tensor_alignment);
669
 
@@ -1616,7 +1617,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
1616
 
1617
  struct ggml_context * ctx0 = ggml_init(params);
1618
 
1619
- ggml_cgraph * gf = ggml_new_graph(ctx0);
1620
 
1621
  ggml_allocr * alloc = wstate.alloc_encode.alloc;
1622
 
@@ -2034,7 +2035,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2034
 
2035
  struct ggml_context * ctx0 = ggml_init(params);
2036
 
2037
- ggml_cgraph * gf = ggml_new_graph(ctx0);
2038
 
2039
  ggml_allocr * alloc = wstate.alloc_decode.alloc;
2040
 
@@ -3773,9 +3774,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
3773
  /*.encoder_begin_callback =*/ nullptr,
3774
  /*.encoder_begin_callback_user_data =*/ nullptr,
3775
 
3776
- /*.abort_callback =*/ nullptr,
3777
- /*.abort_callback_user_data =*/ nullptr,
3778
-
3779
  /*.logits_filter_callback =*/ nullptr,
3780
  /*.logits_filter_callback_user_data =*/ nullptr,
3781
  };
@@ -4535,7 +4533,7 @@ int whisper_full_with_state(
4535
 
4536
  // initial prompt
4537
  if (!params.prompt_tokens && params.initial_prompt) {
4538
- prompt_tokens.resize(2048);
4539
  prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
4540
  params.prompt_tokens = prompt_tokens.data();
4541
  params.prompt_n_tokens = prompt_tokens.size();
@@ -5432,7 +5430,7 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
5432
  // b: N*N*sizeof(float)
5433
  // c: N*N*sizeof(float)
5434
  // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
5435
- std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead());
5436
  std::vector<uint8_t> work;
5437
 
5438
  // put a bunch of random data in the buffer
@@ -5483,17 +5481,19 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
5483
 
5484
  struct ggml_tensor * c = ggml_mul_mat(ctx0, a, b);
5485
 
5486
- struct ggml_cgraph gf = ggml_build_forward(c);
 
 
5487
 
5488
  double tsum = 0.0;
5489
 
5490
  // heat-up
5491
- ggml_graph_compute_helper(work, &gf, n_threads, nullptr , nullptr);
5492
 
5493
  for (int i = 0; i < n_max; ++i) {
5494
  const int64_t t0 = ggml_time_us();
5495
 
5496
- ggml_graph_compute_helper(work, &gf, n_threads, nullptr, nullptr);
5497
 
5498
  const int64_t t1 = ggml_time_us();
5499
 
 
120
  //#define WHISPER_USE_FLASH_ATTN
121
  //#define WHISPER_USE_FLASH_FF
122
  #define WHISPER_MAX_DECODERS 16
123
+ #define WHISPER_MAX_NODES 4096
124
 
125
  //
126
  // ggml helpers
 
664
  auto & meta = allocr.meta;
665
  auto & data = allocr.data;
666
 
667
+ meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
668
 
669
  alloc = ggml_allocr_new_measure(tensor_alignment);
670
 
 
1617
 
1618
  struct ggml_context * ctx0 = ggml_init(params);
1619
 
1620
+ ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
1621
 
1622
  ggml_allocr * alloc = wstate.alloc_encode.alloc;
1623
 
 
2035
 
2036
  struct ggml_context * ctx0 = ggml_init(params);
2037
 
2038
+ ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
2039
 
2040
  ggml_allocr * alloc = wstate.alloc_decode.alloc;
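
Both the encoder and decoder builders now create their graphs with ggml_new_graph_custom() and an explicit WHISPER_MAX_NODES budget, because the larger models overflow the default 2048-node graph. A condensed sketch of the pattern visible in the two hunks above; the meta buffer and allocator correspond to the whisper_allocr fields, and the input shape is a placeholder:

    #include "ggml.h"
    #include "ggml-alloc.h"

    #define WHISPER_MAX_NODES 4096

    // condensed sketch: graph metadata lives in a no_alloc context, the graph gets an
    // explicit node budget, and tensor data is placed by the (measuring) allocator
    static struct ggml_cgraph * build_graph_sketch(void * meta, size_t meta_size, struct ggml_allocr * alloc) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ meta_size,
            /*.mem_buffer =*/ meta,
            /*.no_alloc   =*/ true,
        };
        struct ggml_context * ctx0 = ggml_init(params);
        struct ggml_cgraph  * gf   = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);

        struct ggml_tensor * inp = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 80, 3000); // placeholder shape
        ggml_allocr_alloc(alloc, inp);

        // ... build the model ops on top of inp, then:
        // ggml_build_forward_expand(gf, result);
        return gf;
    }
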
2041
 
 
3774
  /*.encoder_begin_callback =*/ nullptr,
3775
  /*.encoder_begin_callback_user_data =*/ nullptr,
3776
 
 
 
 
3777
  /*.logits_filter_callback =*/ nullptr,
3778
  /*.logits_filter_callback_user_data =*/ nullptr,
3779
  };
 
4533
 
4534
  // initial prompt
4535
  if (!params.prompt_tokens && params.initial_prompt) {
4536
+ prompt_tokens.resize(1024);
4537
  prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
4538
  params.prompt_tokens = prompt_tokens.data();
4539
  params.prompt_n_tokens = prompt_tokens.size();
 
5430
  // b: N*N*sizeof(float)
5431
  // c: N*N*sizeof(float)
5432
  // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
5433
+ std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead() + ggml_graph_overhead());
5434
  std::vector<uint8_t> work;
5435
 
5436
  // put a bunch of random data in the buffer
 
5481
 
5482
  struct ggml_tensor * c = ggml_mul_mat(ctx0, a, b);
5483
 
5484
+ struct ggml_cgraph * gf = ggml_new_graph(ctx0);
5485
+
5486
+ ggml_build_forward_expand(gf, c);
5487
 
5488
  double tsum = 0.0;
5489
 
5490
  // heat-up
5491
+ ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
5492
 
5493
  for (int i = 0; i < n_max; ++i) {
5494
  const int64_t t0 = ggml_time_us();
5495
 
5496
+ ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
5497
 
5498
  const int64_t t1 = ggml_time_us();
5499