ggerganov committed on
Commit
3bd52ce
·
unverified ·
1 Parent(s): 70332a0

ggml : sync latest changes from ggml and llama.cpp

Browse files
Files changed (3) hide show
  1. Makefile +2 -2
  2. ggml.c +173 -89
  3. ggml.h +11 -9
Makefile CHANGED
@@ -157,7 +157,7 @@ endif
157
  ifneq ($(filter armv7%,$(UNAME_M)),)
158
  # 32-bit ARM, for example on Armbian or possibly raspbian
159
  CFLAGS += -mfpu=neon -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
160
-
161
  # 64-bit ARM, use these (TODO: auto-detect 64-bit)
162
  # CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
163
  endif
@@ -190,7 +190,7 @@ default: main bench
190
  ggml.o: ggml.c ggml.h
191
  $(CC) $(CFLAGS) -c ggml.c -o ggml.o
192
 
193
- whisper.o: whisper.cpp whisper.h
194
  $(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o
195
 
196
  libwhisper.a: ggml.o whisper.o
 
157
  ifneq ($(filter armv7%,$(UNAME_M)),)
158
  # 32-bit ARM, for example on Armbian or possibly raspbian
159
  CFLAGS += -mfpu=neon -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
160
+
161
  # 64-bit ARM, use these (TODO: auto-detect 64-bit)
162
  # CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
163
  endif
 
190
  ggml.o: ggml.c ggml.h
191
  $(CC) $(CFLAGS) -c ggml.c -o ggml.o
192
 
193
+ whisper.o: whisper.cpp whisper.h ggml.h
194
  $(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o
195
 
196
  libwhisper.a: ggml.o whisper.o
ggml.c CHANGED
@@ -1,4 +1,4 @@
1
- // Defines CLOCK_MONOTONIC and asprintf on Linux
2
  #define _GNU_SOURCE
3
 
4
  #include "ggml.h"
@@ -26,14 +26,9 @@
26
  #define static_assert(cond, msg) struct global_scope_noop_trick
27
  #endif
28
 
29
- #if defined _MSC_VER || defined(__MINGW32__)
30
 
31
- #if !defined(__MINGW32__)
32
- #include <Windows.h>
33
- #else
34
- // ref: https://github.com/ggerganov/whisper.cpp/issues/168
35
  #include <windows.h>
36
- #endif
37
 
38
  typedef volatile LONG atomic_int;
39
  typedef atomic_int atomic_bool;
@@ -55,6 +50,7 @@ typedef HANDLE pthread_t;
55
 
56
  typedef DWORD thread_ret_t;
57
  static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
 
58
  HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
59
  if (handle == NULL)
60
  {
@@ -66,6 +62,7 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
66
  }
67
 
68
  static int pthread_join(pthread_t thread, void* unused) {
 
69
  return (int) WaitForSingleObject(thread, INFINITE);
70
  }
71
 
@@ -117,6 +114,14 @@ typedef void* thread_ret_t;
117
  #define GGML_MEM_ALIGN 16
118
  #endif
119
 
 
 
 
 
 
 
 
 
120
  #define UNUSED(x) (void)(x)
121
  #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
122
 
@@ -231,12 +236,12 @@ static inline float fp32_from_bits(uint32_t w) {
231
  }
232
 
233
  static inline uint32_t fp32_to_bits(float f) {
234
- union {
235
- float as_value;
236
- uint32_t as_bits;
237
- } fp32;
238
- fp32.as_value = f;
239
- return fp32.as_bits;
240
  }
241
 
242
  static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
@@ -486,6 +491,77 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
486
  }
487
  #endif
488
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
489
  // method 5
490
  // blocks of QK elements
491
  // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
@@ -1213,15 +1289,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1213
  #define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
1214
  #define GGML_F32x4_ADD vaddq_f32
1215
  #define GGML_F32x4_MUL vmulq_f32
1216
- #if defined(__ARM_FEATURE_QRDMX)
1217
- #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
1218
- #else
1219
- #define GGML_F32x4_REDUCE_ONE(x) \
1220
- (vgetq_lane_f32(x, 0) + \
1221
- vgetq_lane_f32(x, 1) + \
1222
- vgetq_lane_f32(x, 2) + \
1223
- vgetq_lane_f32(x, 3))
1224
- #endif
1225
  #define GGML_F32x4_REDUCE(res, x) \
1226
  { \
1227
  for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
@@ -1844,55 +1912,43 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1844
  // 4-bit -> 8-bit
1845
  const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
1846
  const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
1847
-
1848
  const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
1849
  const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
1850
 
1851
  const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
1852
  const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
1853
-
1854
  const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
1855
  const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
1856
 
1857
  // sub 8
1858
  const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
1859
  const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
1860
-
1861
  const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
1862
  const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
1863
 
1864
  const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
1865
  const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
1866
-
1867
  const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
1868
  const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
1869
 
1870
  #if defined(__ARM_FEATURE_DOTPROD)
1871
- // dot product into int16x8_t
1872
  int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
1873
  int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
1874
 
1875
  p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
1876
  p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
1877
 
1878
- // scalar
1879
- #if defined(__ARM_FEATURE_QRDMX)
1880
- sum0 += x0->d * y0->d * vaddvq_s32(p_0);
1881
- sum1 += x1->d * y1->d * vaddvq_s32(p_1);
1882
  #else
1883
- sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
1884
- sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
1885
- #endif
1886
- #else
1887
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
1888
  const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
1889
-
1890
  const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
1891
  const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
1892
 
1893
  const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
1894
  const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
1895
-
1896
  const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
1897
  const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
1898
 
@@ -1905,14 +1961,8 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1905
  const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
1906
  const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
1907
 
1908
- // scalar
1909
- #if defined(__ARM_FEATURE_QRDMX)
1910
- sum0 += x0->d * y0->d * vaddvq_s16(p_0);
1911
- sum1 += x1->d * y1->d * vaddvq_s16(p_1);
1912
- #else
1913
- sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
1914
- sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
1915
- #endif
1916
  #endif
1917
  }
1918
 
@@ -2155,18 +2205,20 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2155
  const uint8_t * restrict p0 = x[i].qs;
2156
  const uint8_t * restrict p1 = y[i].qs;
2157
 
 
2158
  for (int j = 0; j < QK/2; j++) {
2159
  const uint8_t v0 = p0[j];
2160
  const uint8_t v1 = p1[j];
2161
 
2162
- const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
2163
- const float f1 = d0*((int8_t) (v0 >> 4) - 8);
2164
 
2165
- const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
2166
- const float f3 = d1*((int8_t) (v1 >> 4) - 8);
2167
 
2168
- sumf += f0*f2 + f1*f3;
2169
  }
 
2170
  }
2171
  #endif
2172
 
@@ -2258,36 +2310,71 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2258
  float sum10 = 0.0f;
2259
  float sum11 = 0.0f;
2260
 
2261
- for (int i = 0; i < nb; ++i) {
2262
  const block_q4_1 * restrict x0 = &x[i + 0];
2263
  const block_q4_1 * restrict y0 = &y[i + 0];
 
 
2264
 
2265
  const uint8x16_t m4b = vdupq_n_u8(0xf);
2266
 
2267
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2268
  const uint8x16_t v1_0 = vld1q_u8(y0->qs);
 
 
2269
 
2270
- // and with 0xf
2271
  const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
2272
  const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
2273
-
2274
  const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
2275
  const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
2276
 
2277
- // dot product into uint16x8_t
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2278
  const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
2279
  const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
2280
-
2281
  const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
2282
  const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
2283
 
2284
- const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
2285
- const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
 
 
2286
 
2287
- sum00 += x0->m*y0->m;
2288
- sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
2289
- sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
2290
- sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
 
 
 
 
 
 
 
 
2291
  }
2292
 
2293
  sumf = QK*sum00 + sum01 + sum10 + sum11;
@@ -2563,29 +2650,26 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
2563
  //
2564
 
2565
  static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
2566
- QK,
2567
- QK,
2568
- 1,
2569
- 1,
2570
- 1,
2571
- 1,
2572
- 1,
2573
  };
2574
-
2575
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
2576
 
2577
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
2578
- sizeof(block_q4_0),
2579
- sizeof(block_q4_1),
2580
- sizeof(int8_t ),
2581
- sizeof(int16_t),
2582
- sizeof(int32_t),
2583
- sizeof(ggml_fp16_t),
2584
- sizeof(float ),
2585
  };
2586
-
2587
- // don't forget to update the array above when adding new types
2588
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
2589
 
2590
  static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2591
  "NONE",
@@ -2972,7 +3056,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
2972
 
2973
  *ctx = (struct ggml_context) {
2974
  /*.mem_size =*/ params.mem_size,
2975
- /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
2976
  /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
2977
  /*.no_alloc =*/ params.no_alloc,
2978
  /*.n_objects =*/ 0,
@@ -3007,7 +3091,7 @@ void ggml_free(struct ggml_context * ctx) {
3007
  __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
3008
 
3009
  if (ctx->mem_buffer_owned) {
3010
- free(ctx->mem_buffer);
3011
  }
3012
 
3013
  found = true;
@@ -6441,7 +6525,7 @@ static void ggml_compute_forward_mul_mat_f32(
6441
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6442
  ne11, ne01, ne10,
6443
  1.0f, y, ne10,
6444
- x, ne10,
6445
  0.0f, d, ne01);
6446
  }
6447
  }
@@ -6613,7 +6697,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6613
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6614
  ne11, ne01, ne10,
6615
  1.0f, y, ne10,
6616
- x, ne10,
6617
  0.0f, d, ne01);
6618
  }
6619
  }
@@ -6826,7 +6910,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
6826
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6827
  ne11, ne01, ne10,
6828
  1.0f, y, ne10,
6829
- x, ne10,
6830
  0.0f, d, ne01);
6831
  }
6832
  }
@@ -9279,7 +9363,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
9279
  struct ggml_cgraph result = {
9280
  /*.n_nodes =*/ 0,
9281
  /*.n_leafs =*/ 0,
9282
- /*.n_threads =*/ 0,
9283
  /*.work_size =*/ 0,
9284
  /*.work =*/ NULL,
9285
  /*.nodes =*/ { NULL },
@@ -9899,8 +9983,8 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
9899
 
9900
  GGML_PRINT("=== GRAPH ===\n");
9901
 
9902
- GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
9903
- GGML_PRINT_DEBUG("total work size = %zu bytes\n",cgraph->work_size);
9904
 
9905
  GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
9906
  for (int i = 0; i < cgraph->n_nodes; i++) {
 
1
+ // Defines CLOCK_MONOTONIC on Linux
2
  #define _GNU_SOURCE
3
 
4
  #include "ggml.h"
 
26
  #define static_assert(cond, msg) struct global_scope_noop_trick
27
  #endif
28
 
29
+ #if defined(_WIN32)
30
 
 
 
 
 
31
  #include <windows.h>
 
32
 
33
  typedef volatile LONG atomic_int;
34
  typedef atomic_int atomic_bool;
 
50
 
51
  typedef DWORD thread_ret_t;
52
  static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
53
+ (void) unused;
54
  HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
55
  if (handle == NULL)
56
  {
 
62
  }
63
 
64
  static int pthread_join(pthread_t thread, void* unused) {
65
+ (void) unused;
66
  return (int) WaitForSingleObject(thread, INFINITE);
67
  }
68
 
 
114
  #define GGML_MEM_ALIGN 16
115
  #endif
116
 
117
+ #if defined(_MSC_VER) || defined(__MINGW32__)
118
+ #define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
119
+ #define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
120
+ #else
121
+ #define GGML_ALIGNED_MALLOC(size) aligned_alloc(GGML_MEM_ALIGN, size)
122
+ #define GGML_ALIGNED_FREE(ptr) free(ptr)
123
+ #endif
124
+
125
  #define UNUSED(x) (void)(x)
126
  #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
127
 
 
236
  }
237
 
238
  static inline uint32_t fp32_to_bits(float f) {
239
+ union {
240
+ float as_value;
241
+ uint32_t as_bits;
242
+ } fp32;
243
+ fp32.as_value = f;
244
+ return fp32.as_bits;
245
  }
246
 
247
  static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
 
491
  }
492
  #endif
493
 
494
+ #if __ARM_NEON
495
+
496
+ #if !defined(__aarch64__)
497
+
498
+ inline static uint16_t vaddvq_u8(uint8x16_t v) {
499
+ return
500
+ (uint16_t)vgetq_lane_u8(v, 0) + (uint16_t)vgetq_lane_u8(v, 1) +
501
+ (uint16_t)vgetq_lane_u8(v, 2) + (uint16_t)vgetq_lane_u8(v, 3) +
502
+ (uint16_t)vgetq_lane_u8(v, 4) + (uint16_t)vgetq_lane_u8(v, 5) +
503
+ (uint16_t)vgetq_lane_u8(v, 6) + (uint16_t)vgetq_lane_u8(v, 7) +
504
+ (uint16_t)vgetq_lane_u8(v, 8) + (uint16_t)vgetq_lane_u8(v, 9) +
505
+ (uint16_t)vgetq_lane_u8(v, 10) + (uint16_t)vgetq_lane_u8(v, 11) +
506
+ (uint16_t)vgetq_lane_u8(v, 12) + (uint16_t)vgetq_lane_u8(v, 13) +
507
+ (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
508
+ }
509
+
510
+ inline static int32_t vaddvq_s16(int16x8_t v) {
511
+ return
512
+ (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
513
+ (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
514
+ (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
515
+ (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
516
+ }
517
+
518
+ inline static uint32_t vaddvq_u16(uint16x8_t v) {
519
+ return
520
+ (uint32_t)vgetq_lane_u16(v, 0) + (uint32_t)vgetq_lane_u16(v, 1) +
521
+ (uint32_t)vgetq_lane_u16(v, 2) + (uint32_t)vgetq_lane_u16(v, 3) +
522
+ (uint32_t)vgetq_lane_u16(v, 4) + (uint32_t)vgetq_lane_u16(v, 5) +
523
+ (uint32_t)vgetq_lane_u16(v, 6) + (uint32_t)vgetq_lane_u16(v, 7);
524
+ }
525
+
526
+ inline static int32_t vaddvq_s32(int32x4_t v) {
527
+ return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
528
+ }
529
+
530
+ inline static float vaddvq_f32(float32x4_t v) {
531
+ return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
532
+ }
533
+
534
+ inline float vminvq_f32(float32x4_t v) {
535
+ return
536
+ MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
537
+ MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
538
+ }
539
+
540
+ inline float vmaxvq_f32(float32x4_t v) {
541
+ return
542
+ MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
543
+ MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
544
+ }
545
+
546
+ inline int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
547
+ return vget_low_s8(vcombine_s8(a, b));
548
+ }
549
+
550
+ inline int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
551
+ return vget_high_s8(vcombine_s8(a, b));
552
+ }
553
+
554
+ inline uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
555
+ return vget_low_u8(vcombine_u8(a, b));
556
+ }
557
+
558
+ inline uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
559
+ return vget_high_u8(vcombine_u8(a, b));
560
+ }
561
+
562
+ #endif
563
+ #endif
564
+
565
  // method 5
566
  // blocks of QK elements
567
  // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
 
1289
  #define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
1290
  #define GGML_F32x4_ADD vaddq_f32
1291
  #define GGML_F32x4_MUL vmulq_f32
1292
+ #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
 
 
 
 
 
 
 
 
1293
  #define GGML_F32x4_REDUCE(res, x) \
1294
  { \
1295
  for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
 
1912
  // 4-bit -> 8-bit
1913
  const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
1914
  const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
 
1915
  const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
1916
  const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
1917
 
1918
  const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
1919
  const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
 
1920
  const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
1921
  const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
1922
 
1923
  // sub 8
1924
  const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
1925
  const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
 
1926
  const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
1927
  const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
1928
 
1929
  const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
1930
  const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
 
1931
  const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
1932
  const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
1933
 
1934
  #if defined(__ARM_FEATURE_DOTPROD)
1935
+ // dot product into int32x4_t
1936
  int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
1937
  int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
1938
 
1939
  p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
1940
  p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
1941
 
1942
+ sum0 += x0->d*y0->d*vaddvq_s32(p_0);
1943
+ sum1 += x1->d*y1->d*vaddvq_s32(p_1);
 
 
1944
  #else
1945
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
 
 
 
 
1946
  const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
 
1947
  const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
1948
  const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
1949
 
1950
  const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
1951
  const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
 
1952
  const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
1953
  const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
1954
 
 
1961
  const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
1962
  const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
1963
 
1964
+ sum0 += x0->d*y0->d*vaddvq_s16(p_0);
1965
+ sum1 += x1->d*y1->d*vaddvq_s16(p_1);
 
 
 
 
 
 
1966
  #endif
1967
  }
1968
 
 
2205
  const uint8_t * restrict p0 = x[i].qs;
2206
  const uint8_t * restrict p1 = y[i].qs;
2207
 
2208
+ int sumi = 0;
2209
  for (int j = 0; j < QK/2; j++) {
2210
  const uint8_t v0 = p0[j];
2211
  const uint8_t v1 = p1[j];
2212
 
2213
+ const int8_t i0 = (int8_t) (v0 & 0xf) - 8;
2214
+ const int8_t i1 = (int8_t) (v0 >> 4) - 8;
2215
 
2216
+ const int8_t i2 = (int8_t) (v1 & 0xf) - 8;
2217
+ const int8_t i3 = (int8_t) (v1 >> 4) - 8;
2218
 
2219
+ sumi += i0*i2 + i1*i3;
2220
  }
2221
+ sumf += d0 * d1 * sumi;
2222
  }
2223
  #endif
2224
 
 
2310
  float sum10 = 0.0f;
2311
  float sum11 = 0.0f;
2312
 
2313
+ for (int i = 0; i < nb; i += 2) {
2314
  const block_q4_1 * restrict x0 = &x[i + 0];
2315
  const block_q4_1 * restrict y0 = &y[i + 0];
2316
+ const block_q4_1 * restrict x1 = &x[i + 1];
2317
+ const block_q4_1 * restrict y1 = &y[i + 1];
2318
 
2319
  const uint8x16_t m4b = vdupq_n_u8(0xf);
2320
 
2321
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2322
  const uint8x16_t v1_0 = vld1q_u8(y0->qs);
2323
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2324
+ const uint8x16_t v1_1 = vld1q_u8(y1->qs);
2325
 
2326
+ // 4-bit -> 8-bit
2327
  const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
2328
  const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
 
2329
  const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
2330
  const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
2331
 
2332
+ const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
2333
+ const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
2334
+ const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
2335
+ const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
2336
+
2337
+ sum00 += x0->m*y0->m;
2338
+ sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
2339
+ sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
2340
+
2341
+ sum00 += x1->m*y1->m;
2342
+ sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
2343
+ sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
2344
+
2345
+ #if defined(__ARM_FEATURE_DOTPROD)
2346
+ // dot product into int32x4_t
2347
+ int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l);
2348
+ int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l);
2349
+
2350
+ p_0 = vdotq_s32(p_0, v0_0h, v1_0h);
2351
+ p_1 = vdotq_s32(p_1, v0_1h, v1_1h);
2352
+
2353
+ sum11 += x0->d*y0->d*vaddvq_s32(p_0);
2354
+ sum11 += x1->d*y1->d*vaddvq_s32(p_1);
2355
+ #else
2356
  const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
2357
  const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
 
2358
  const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
2359
  const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
2360
 
2361
+ const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
2362
+ const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
2363
+ const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
2364
+ const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
2365
 
2366
+ const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
2367
+ const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
2368
+
2369
+ const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
2370
+ const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
2371
+
2372
+ const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
2373
+ const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
2374
+
2375
+ sum11 += x0->d*y0->d*vaddvq_u16(p_0);
2376
+ sum11 += x1->d*y1->d*vaddvq_u16(p_1);
2377
+ #endif
2378
  }
2379
 
2380
  sumf = QK*sum00 + sum01 + sum10 + sum11;
 
2650
  //
2651
 
2652
  static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
2653
+ [GGML_TYPE_F32] = 1,
2654
+ [GGML_TYPE_F16] = 1,
2655
+ [GGML_TYPE_Q4_0] = QK,
2656
+ [GGML_TYPE_Q4_1] = QK,
2657
+ [GGML_TYPE_I8] = 1,
2658
+ [GGML_TYPE_I16] = 1,
2659
+ [GGML_TYPE_I32] = 1,
2660
  };
2661
+ static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated");
 
2662
 
2663
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
2664
+ [GGML_TYPE_F32] = sizeof(float),
2665
+ [GGML_TYPE_F16] = sizeof(ggml_fp16_t),
2666
+ [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
2667
+ [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
2668
+ [GGML_TYPE_I8] = sizeof(int8_t),
2669
+ [GGML_TYPE_I16] = sizeof(int16_t),
2670
+ [GGML_TYPE_I32] = sizeof(int32_t),
2671
  };
2672
+ static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated");
 
 
2673
 
2674
  static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2675
  "NONE",
 
3056
 
3057
  *ctx = (struct ggml_context) {
3058
  /*.mem_size =*/ params.mem_size,
3059
+ /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(params.mem_size),
3060
  /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
3061
  /*.no_alloc =*/ params.no_alloc,
3062
  /*.n_objects =*/ 0,
 
3091
  __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
3092
 
3093
  if (ctx->mem_buffer_owned) {
3094
+ GGML_ALIGNED_FREE(ctx->mem_buffer);
3095
  }
3096
 
3097
  found = true;
 
6525
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6526
  ne11, ne01, ne10,
6527
  1.0f, y, ne10,
6528
+ x, ne00,
6529
  0.0f, d, ne01);
6530
  }
6531
  }
 
6697
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6698
  ne11, ne01, ne10,
6699
  1.0f, y, ne10,
6700
+ x, ne00,
6701
  0.0f, d, ne01);
6702
  }
6703
  }
 
6910
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6911
  ne11, ne01, ne10,
6912
  1.0f, y, ne10,
6913
+ x, ne00,
6914
  0.0f, d, ne01);
6915
  }
6916
  }
 
9363
  struct ggml_cgraph result = {
9364
  /*.n_nodes =*/ 0,
9365
  /*.n_leafs =*/ 0,
9366
+ /*.n_threads =*/ GGML_DEFAULT_N_THREADS,
9367
  /*.work_size =*/ 0,
9368
  /*.work =*/ NULL,
9369
  /*.nodes =*/ { NULL },
 
9983
 
9984
  GGML_PRINT("=== GRAPH ===\n");
9985
 
9986
+ GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
9987
+ GGML_PRINT_DEBUG("total work size = %zu bytes\n", cgraph->work_size);
9988
 
9989
  GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
9990
  for (int i = 0; i < cgraph->n_nodes; i++) {
ggml.h CHANGED
@@ -177,11 +177,12 @@ extern "C" {
177
  #include <stddef.h>
178
  #include <stdbool.h>
179
 
180
- #define GGML_MAX_DIMS 4
181
- #define GGML_MAX_NODES 4096
182
- #define GGML_MAX_PARAMS 16
183
- #define GGML_MAX_CONTEXTS 64
184
- #define GGML_MAX_OPT 4
 
185
 
186
  #ifdef __ARM_NEON
187
  // we use the built-in 16-bit float type
@@ -198,13 +199,14 @@ struct ggml_object;
198
  struct ggml_context;
199
 
200
  enum ggml_type {
201
- GGML_TYPE_Q4_0,
202
- GGML_TYPE_Q4_1,
 
 
 
203
  GGML_TYPE_I8,
204
  GGML_TYPE_I16,
205
  GGML_TYPE_I32,
206
- GGML_TYPE_F16,
207
- GGML_TYPE_F32,
208
  GGML_TYPE_COUNT,
209
  };
210
 
 
177
  #include <stddef.h>
178
  #include <stdbool.h>
179
 
180
+ #define GGML_MAX_DIMS 4
181
+ #define GGML_MAX_NODES 4096
182
+ #define GGML_MAX_PARAMS 16
183
+ #define GGML_MAX_CONTEXTS 64
184
+ #define GGML_MAX_OPT 4
185
+ #define GGML_DEFAULT_N_THREADS 4
186
 
187
  #ifdef __ARM_NEON
188
  // we use the built-in 16-bit float type
 
199
  struct ggml_context;
200
 
201
  enum ggml_type {
202
+ // explicitly numbered values are used in llama.cpp files
203
+ GGML_TYPE_F32 = 0,
204
+ GGML_TYPE_F16 = 1,
205
+ GGML_TYPE_Q4_0 = 2,
206
+ GGML_TYPE_Q4_1 = 3,
207
  GGML_TYPE_I8,
208
  GGML_TYPE_I16,
209
  GGML_TYPE_I32,
 
 
210
  GGML_TYPE_COUNT,
211
  };
212