ggerganov committed on
Commit
7343760
·
1 Parent(s): bfe2a5f

ggml : move 32-bit arm compat in ggml-impl.h (llama/6865)

Browse files
Files changed (2) hide show
  1. ggml-impl.h +256 -4
  2. ggml-quants.c +0 -287
ggml-impl.h CHANGED
@@ -52,7 +52,7 @@ extern "C" {
52
  // 16-bit float
53
  // on Arm, we use __fp16
54
  // on x86, we use uint16_t
55
- #if defined(__ARM_NEON) && !defined(_MSC_VER)
56
 
57
  // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
58
  //
@@ -60,8 +60,262 @@ extern "C" {
60
  //
61
  #include <arm_neon.h>
62
 
 
 
 
 
 
 
 
 
63
  typedef __fp16 ggml_fp16_internal_t;
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
66
  #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
67
 
@@ -82,8 +336,6 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
82
 
83
  #else
84
 
85
- typedef uint16_t ggml_fp16_internal_t;
86
-
87
  #ifdef __wasm_simd128__
88
  #include <wasm_simd128.h>
89
  #else
@@ -228,7 +480,7 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
228
 
229
  #endif // __F16C__
230
 
231
- #endif // __ARM_NEON
232
 
233
  // precomputed f32 table for f16 (256 KB)
234
  // defined in ggml.c, initialized in ggml_init()
 
52
  // 16-bit float
53
  // on Arm, we use __fp16
54
  // on x86, we use uint16_t
55
+ #if defined(__ARM_NEON)
56
 
57
  // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
58
  //
 
60
  //
61
  #include <arm_neon.h>
62
 
63
+ #ifdef _MSC_VER
64
+
65
+ typedef uint16_t ggml_fp16_internal_t;
66
+
67
+ #define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
68
+
69
+ #else
70
+
71
  typedef __fp16 ggml_fp16_internal_t;
72
 
73
+ #define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
74
+
75
+ #endif // _MSC_VER
76
+
77
+ #if !defined(__aarch64__)
78
+
79
+ // 32-bit ARM compatibility
80
+
81
+ // vaddvq_s16
82
+ // vpaddq_s16
83
+ // vpaddq_s32
84
+ // vaddvq_s32
85
+ // vaddvq_f32
86
+ // vmaxvq_f32
87
+ // vcvtnq_s32_f32
88
+ // vzip1_u8
89
+ // vzip2_u8
90
+
91
+ inline static int32_t vaddvq_s16(int16x8_t v) {
92
+ return
93
+ (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
94
+ (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
95
+ (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
96
+ (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
97
+ }
98
+
99
+ inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
100
+ int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
101
+ int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
102
+ return vcombine_s16(a0, b0);
103
+ }
104
+
105
+ inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
106
+ int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
107
+ int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
108
+ return vcombine_s32(a0, b0);
109
+ }
110
+
111
+ inline static int32_t vaddvq_s32(int32x4_t v) {
112
+ return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
113
+ }
114
+
115
+ inline static float vaddvq_f32(float32x4_t v) {
116
+ return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
117
+ }
118
+
119
+ inline static float vmaxvq_f32(float32x4_t v) {
120
+ return
121
+ MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
122
+ MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
123
+ }
124
+
125
+ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
126
+ int32x4_t res;
127
+
128
+ res[0] = roundf(vgetq_lane_f32(v, 0));
129
+ res[1] = roundf(vgetq_lane_f32(v, 1));
130
+ res[2] = roundf(vgetq_lane_f32(v, 2));
131
+ res[3] = roundf(vgetq_lane_f32(v, 3));
132
+
133
+ return res;
134
+ }
135
+
136
+ inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
137
+ uint8x8_t res;
138
+
139
+ res[0] = a[0]; res[1] = b[0];
140
+ res[2] = a[1]; res[3] = b[1];
141
+ res[4] = a[2]; res[5] = b[2];
142
+ res[6] = a[3]; res[7] = b[3];
143
+
144
+ return res;
145
+ }
146
+
147
+ inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
148
+ uint8x8_t res;
149
+
150
+ res[0] = a[4]; res[1] = b[4];
151
+ res[2] = a[5]; res[3] = b[5];
152
+ res[4] = a[6]; res[5] = b[6];
153
+ res[6] = a[7]; res[7] = b[7];
154
+
155
+ return res;
156
+ }
157
+
158
+ // vld1q_s16_x2
159
+ // vld1q_u8_x2
160
+ // vld1q_u8_x4
161
+ // vld1q_s8_x2
162
+ // vld1q_s8_x4
163
+ // TODO: double-check these work correctly
164
+
165
+ typedef struct ggml_int16x8x2_t {
166
+ int16x8_t val[2];
167
+ } ggml_int16x8x2_t;
168
+
169
+ inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
170
+ ggml_int16x8x2_t res;
171
+
172
+ res.val[0] = vld1q_s16(ptr + 0);
173
+ res.val[1] = vld1q_s16(ptr + 8);
174
+
175
+ return res;
176
+ }
177
+
178
+ typedef struct ggml_uint8x16x2_t {
179
+ uint8x16_t val[2];
180
+ } ggml_uint8x16x2_t;
181
+
182
+ inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
183
+ ggml_uint8x16x2_t res;
184
+
185
+ res.val[0] = vld1q_u8(ptr + 0);
186
+ res.val[1] = vld1q_u8(ptr + 16);
187
+
188
+ return res;
189
+ }
190
+
191
+ typedef struct ggml_uint8x16x4_t {
192
+ uint8x16_t val[4];
193
+ } ggml_uint8x16x4_t;
194
+
195
+ inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
196
+ ggml_uint8x16x4_t res;
197
+
198
+ res.val[0] = vld1q_u8(ptr + 0);
199
+ res.val[1] = vld1q_u8(ptr + 16);
200
+ res.val[2] = vld1q_u8(ptr + 32);
201
+ res.val[3] = vld1q_u8(ptr + 48);
202
+
203
+ return res;
204
+ }
205
+
206
+ typedef struct ggml_int8x16x2_t {
207
+ int8x16_t val[2];
208
+ } ggml_int8x16x2_t;
209
+
210
+ inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
211
+ ggml_int8x16x2_t res;
212
+
213
+ res.val[0] = vld1q_s8(ptr + 0);
214
+ res.val[1] = vld1q_s8(ptr + 16);
215
+
216
+ return res;
217
+ }
218
+
219
+ typedef struct ggml_int8x16x4_t {
220
+ int8x16_t val[4];
221
+ } ggml_int8x16x4_t;
222
+
223
+ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
224
+ ggml_int8x16x4_t res;
225
+
226
+ res.val[0] = vld1q_s8(ptr + 0);
227
+ res.val[1] = vld1q_s8(ptr + 16);
228
+ res.val[2] = vld1q_s8(ptr + 32);
229
+ res.val[3] = vld1q_s8(ptr + 48);
230
+
231
+ return res;
232
+ }
233
+
234
+ // NOTE: not tested
235
+ inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
236
+ int8x16_t res;
237
+
238
+ res[ 0] = a[b[ 0]];
239
+ res[ 1] = a[b[ 1]];
240
+ res[ 2] = a[b[ 2]];
241
+ res[ 3] = a[b[ 3]];
242
+ res[ 4] = a[b[ 4]];
243
+ res[ 5] = a[b[ 5]];
244
+ res[ 6] = a[b[ 6]];
245
+ res[ 7] = a[b[ 7]];
246
+ res[ 8] = a[b[ 8]];
247
+ res[ 9] = a[b[ 9]];
248
+ res[10] = a[b[10]];
249
+ res[11] = a[b[11]];
250
+ res[12] = a[b[12]];
251
+ res[13] = a[b[13]];
252
+ res[14] = a[b[14]];
253
+ res[15] = a[b[15]];
254
+
255
+ return res;
256
+ }
257
+
258
+ // NOTE: not tested
259
+ inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
260
+ uint8x16_t res;
261
+
262
+ res[ 0] = a[b[ 0]];
263
+ res[ 1] = a[b[ 1]];
264
+ res[ 2] = a[b[ 2]];
265
+ res[ 3] = a[b[ 3]];
266
+ res[ 4] = a[b[ 4]];
267
+ res[ 5] = a[b[ 5]];
268
+ res[ 6] = a[b[ 6]];
269
+ res[ 7] = a[b[ 7]];
270
+ res[ 8] = a[b[ 8]];
271
+ res[ 9] = a[b[ 9]];
272
+ res[10] = a[b[10]];
273
+ res[11] = a[b[11]];
274
+ res[12] = a[b[12]];
275
+ res[13] = a[b[13]];
276
+ res[14] = a[b[14]];
277
+ res[15] = a[b[15]];
278
+
279
+ return res;
280
+ }
281
+
282
+ #else
283
+
284
+ #define ggml_int16x8x2_t int16x8x2_t
285
+ #define ggml_uint8x16x2_t uint8x16x2_t
286
+ #define ggml_uint8x16x4_t uint8x16x4_t
287
+ #define ggml_int8x16x2_t int8x16x2_t
288
+ #define ggml_int8x16x4_t int8x16x4_t
289
+
290
+ #define ggml_vld1q_s16_x2 vld1q_s16_x2
291
+ #define ggml_vld1q_u8_x2 vld1q_u8_x2
292
+ #define ggml_vld1q_u8_x4 vld1q_u8_x4
293
+ #define ggml_vld1q_s8_x2 vld1q_s8_x2
294
+ #define ggml_vld1q_s8_x4 vld1q_s8_x4
295
+ #define ggml_vqtbl1q_s8 vqtbl1q_s8
296
+ #define ggml_vqtbl1q_u8 vqtbl1q_u8
297
+
298
+ #endif // !defined(__aarch64__)
299
+
300
+ #if !defined(__ARM_FEATURE_DOTPROD)
301
+
302
+ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
303
+ const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
304
+ const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
305
+
306
+ return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
307
+ }
308
+
309
+ #else
310
+
311
+ #define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
312
+
313
+ #endif // !defined(__ARM_FEATURE_DOTPROD)
314
+
315
+ #endif // defined(__ARM_NEON)
316
+
317
+ #if defined(__ARM_NEON) && !defined(_MSC_VER)
318
+
319
  #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
320
  #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
321
 
 
336
 
337
  #else
338
 
 
 
339
  #ifdef __wasm_simd128__
340
  #include <wasm_simd128.h>
341
  #else
 
480
 
481
  #endif // __F16C__
482
 
483
+ #endif // defined(__ARM_NEON) && !defined(_MSC_VER)
484
 
485
  // precomputed f32 table for f16 (256 KB)
486
  // defined in ggml.c, initialized in ggml_init()
ggml-quants.c CHANGED
@@ -20,41 +20,6 @@
20
  #pragma warning(disable: 4244 4267)
21
  #endif
22
 
23
- #ifdef __ARM_NEON
24
-
25
- // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
26
- //
27
- // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
28
- //
29
- #include <arm_neon.h>
30
-
31
- #else
32
-
33
- #ifdef __wasm_simd128__
34
- #include <wasm_simd128.h>
35
- #else
36
- #if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
37
- #include <altivec.h>
38
- #undef bool
39
- #define bool _Bool
40
- #else
41
- #if defined(_MSC_VER) || defined(__MINGW32__)
42
- #include <intrin.h>
43
- #else
44
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
45
- #if !defined(__riscv)
46
- #include <immintrin.h>
47
- #endif
48
- #endif
49
- #endif
50
- #endif
51
- #endif
52
- #endif
53
-
54
- #ifdef __riscv_v_intrinsic
55
- #include <riscv_vector.h>
56
- #endif
57
-
58
  #undef MIN
59
  #undef MAX
60
 
@@ -282,258 +247,6 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
282
  #endif // __AVX__ || __AVX2__ || __AVX512F__
283
  #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
284
 
285
- #if defined(__ARM_NEON)
286
-
287
- #ifdef _MSC_VER
288
-
289
- #define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
290
-
291
- #else
292
-
293
- #define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
294
-
295
- #endif
296
-
297
- #if !defined(__aarch64__)
298
-
299
- // 64-bit compatibility
300
-
301
- // vaddvq_s16
302
- // vpaddq_s16
303
- // vpaddq_s32
304
- // vaddvq_s32
305
- // vaddvq_f32
306
- // vmaxvq_f32
307
- // vcvtnq_s32_f32
308
- // vzip1_u8
309
- // vzip2_u8
310
-
311
- inline static int32_t vaddvq_s16(int16x8_t v) {
312
- return
313
- (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
314
- (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
315
- (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
316
- (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
317
- }
318
-
319
- inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
320
- int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
321
- int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
322
- return vcombine_s16(a0, b0);
323
- }
324
-
325
- inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
326
- int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
327
- int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
328
- return vcombine_s32(a0, b0);
329
- }
330
-
331
- inline static int32_t vaddvq_s32(int32x4_t v) {
332
- return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
333
- }
334
-
335
- inline static float vaddvq_f32(float32x4_t v) {
336
- return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
337
- }
338
-
339
- inline static float vmaxvq_f32(float32x4_t v) {
340
- return
341
- MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
342
- MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
343
- }
344
-
345
- inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
346
- int32x4_t res;
347
-
348
- res[0] = roundf(vgetq_lane_f32(v, 0));
349
- res[1] = roundf(vgetq_lane_f32(v, 1));
350
- res[2] = roundf(vgetq_lane_f32(v, 2));
351
- res[3] = roundf(vgetq_lane_f32(v, 3));
352
-
353
- return res;
354
- }
355
-
356
- inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
357
- uint8x8_t res;
358
-
359
- res[0] = a[0]; res[1] = b[0];
360
- res[2] = a[1]; res[3] = b[1];
361
- res[4] = a[2]; res[5] = b[2];
362
- res[6] = a[3]; res[7] = b[3];
363
-
364
- return res;
365
- }
366
-
367
- inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
368
- uint8x8_t res;
369
-
370
- res[0] = a[4]; res[1] = b[4];
371
- res[2] = a[5]; res[3] = b[5];
372
- res[4] = a[6]; res[5] = b[6];
373
- res[6] = a[7]; res[7] = b[7];
374
-
375
- return res;
376
- }
377
-
378
- // vld1q_s16_x2
379
- // vld1q_u8_x2
380
- // vld1q_u8_x4
381
- // vld1q_s8_x2
382
- // vld1q_s8_x4
383
- // TODO: double-check these work correctly
384
-
385
- typedef struct ggml_int16x8x2_t {
386
- int16x8_t val[2];
387
- } ggml_int16x8x2_t;
388
-
389
- inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
390
- ggml_int16x8x2_t res;
391
-
392
- res.val[0] = vld1q_s16(ptr + 0);
393
- res.val[1] = vld1q_s16(ptr + 8);
394
-
395
- return res;
396
- }
397
-
398
- typedef struct ggml_uint8x16x2_t {
399
- uint8x16_t val[2];
400
- } ggml_uint8x16x2_t;
401
-
402
- inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
403
- ggml_uint8x16x2_t res;
404
-
405
- res.val[0] = vld1q_u8(ptr + 0);
406
- res.val[1] = vld1q_u8(ptr + 16);
407
-
408
- return res;
409
- }
410
-
411
- typedef struct ggml_uint8x16x4_t {
412
- uint8x16_t val[4];
413
- } ggml_uint8x16x4_t;
414
-
415
- inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
416
- ggml_uint8x16x4_t res;
417
-
418
- res.val[0] = vld1q_u8(ptr + 0);
419
- res.val[1] = vld1q_u8(ptr + 16);
420
- res.val[2] = vld1q_u8(ptr + 32);
421
- res.val[3] = vld1q_u8(ptr + 48);
422
-
423
- return res;
424
- }
425
-
426
- typedef struct ggml_int8x16x2_t {
427
- int8x16_t val[2];
428
- } ggml_int8x16x2_t;
429
-
430
- inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
431
- ggml_int8x16x2_t res;
432
-
433
- res.val[0] = vld1q_s8(ptr + 0);
434
- res.val[1] = vld1q_s8(ptr + 16);
435
-
436
- return res;
437
- }
438
-
439
- typedef struct ggml_int8x16x4_t {
440
- int8x16_t val[4];
441
- } ggml_int8x16x4_t;
442
-
443
- inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
444
- ggml_int8x16x4_t res;
445
-
446
- res.val[0] = vld1q_s8(ptr + 0);
447
- res.val[1] = vld1q_s8(ptr + 16);
448
- res.val[2] = vld1q_s8(ptr + 32);
449
- res.val[3] = vld1q_s8(ptr + 48);
450
-
451
- return res;
452
- }
453
-
454
- // NOTE: not tested
455
- inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
456
- int8x16_t res;
457
-
458
- res[ 0] = a[b[ 0]];
459
- res[ 1] = a[b[ 1]];
460
- res[ 2] = a[b[ 2]];
461
- res[ 3] = a[b[ 3]];
462
- res[ 4] = a[b[ 4]];
463
- res[ 5] = a[b[ 5]];
464
- res[ 6] = a[b[ 6]];
465
- res[ 7] = a[b[ 7]];
466
- res[ 8] = a[b[ 8]];
467
- res[ 9] = a[b[ 9]];
468
- res[10] = a[b[10]];
469
- res[11] = a[b[11]];
470
- res[12] = a[b[12]];
471
- res[13] = a[b[13]];
472
- res[14] = a[b[14]];
473
- res[15] = a[b[15]];
474
-
475
- return res;
476
- }
477
-
478
- // NOTE: not tested
479
- inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
480
- uint8x16_t res;
481
-
482
- res[ 0] = a[b[ 0]];
483
- res[ 1] = a[b[ 1]];
484
- res[ 2] = a[b[ 2]];
485
- res[ 3] = a[b[ 3]];
486
- res[ 4] = a[b[ 4]];
487
- res[ 5] = a[b[ 5]];
488
- res[ 6] = a[b[ 6]];
489
- res[ 7] = a[b[ 7]];
490
- res[ 8] = a[b[ 8]];
491
- res[ 9] = a[b[ 9]];
492
- res[10] = a[b[10]];
493
- res[11] = a[b[11]];
494
- res[12] = a[b[12]];
495
- res[13] = a[b[13]];
496
- res[14] = a[b[14]];
497
- res[15] = a[b[15]];
498
-
499
- return res;
500
- }
501
-
502
- #else
503
-
504
- #define ggml_int16x8x2_t int16x8x2_t
505
- #define ggml_uint8x16x2_t uint8x16x2_t
506
- #define ggml_uint8x16x4_t uint8x16x4_t
507
- #define ggml_int8x16x2_t int8x16x2_t
508
- #define ggml_int8x16x4_t int8x16x4_t
509
-
510
- #define ggml_vld1q_s16_x2 vld1q_s16_x2
511
- #define ggml_vld1q_u8_x2 vld1q_u8_x2
512
- #define ggml_vld1q_u8_x4 vld1q_u8_x4
513
- #define ggml_vld1q_s8_x2 vld1q_s8_x2
514
- #define ggml_vld1q_s8_x4 vld1q_s8_x4
515
- #define ggml_vqtbl1q_s8 vqtbl1q_s8
516
- #define ggml_vqtbl1q_u8 vqtbl1q_u8
517
-
518
- #endif
519
-
520
- #if !defined(__ARM_FEATURE_DOTPROD)
521
-
522
- inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
523
- const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
524
- const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
525
-
526
- return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
527
- }
528
-
529
- #else
530
-
531
- #define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
532
-
533
- #endif
534
-
535
- #endif
536
-
537
  #if defined(__ARM_NEON) || defined(__wasm_simd128__)
538
  #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
539
  #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
 
20
  #pragma warning(disable: 4244 4267)
21
  #endif
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  #undef MIN
24
  #undef MAX
25
 
 
247
  #endif // __AVX__ || __AVX2__ || __AVX512F__
248
  #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  #if defined(__ARM_NEON) || defined(__wasm_simd128__)
251
  #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
252
  #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)