ggerganov committed (unverified)
Commit 88deeba · 1 Parent(s): a88e806

sync : ggml (HBM + Metal + style) (#1264)

Files changed (3):
  1. ggml-metal.m     +1 -1
  2. ggml-metal.metal +17 -32
  3. ggml.c           +24 -6
ggml-metal.m CHANGED
@@ -1141,7 +1141,7 @@ void ggml_metal_graph_compute(
                 [encoder setBytes:&freq_base  length:sizeof(float) atIndex:21];
                 [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
 
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
             } break;
         case GGML_OP_DUP:
         case GGML_OP_CPY:
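
Note: this dispatch change pairs with the kernel_rope rewrite in ggml-metal.metal below. RoPE was previously launched with a single thread per threadgroup that walked the whole row serially; with 32 threads per threadgroup, the rewritten kernel stripes the row across the threads (see the sketch after the kernel_rope hunks).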
ggml-metal.metal CHANGED
@@ -220,14 +220,10 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-    //// broadcast
-    //if (tpitg == 0) {
-    //    sum[0] /= ne00;
-    //}
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    const float mean  = sum[0];
+    const float mean  = sum[0] / ne00;
 
     // recenter and VARIANCE
+    threadgroup_barrier(mem_flags::mem_threadgroup);
     device float * y = dst + tgpig*ne00;
     sum[tpitg] = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
@@ -235,12 +231,6 @@ kernel void kernel_norm(
         sum[tpitg] += y[i00] * y[i00];
     }
 
-    //// VARIANCE
-    //// parallel sum
-    //sum[tpitg] = 0.0f;
-    //for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-    //    sum[tpitg] += y[i00] * y[i00];
-    //}
     // reduce
     threadgroup_barrier(mem_flags::mem_threadgroup);
     for (uint i = ntg/2; i > 0; i /= 2) {
@@ -249,12 +239,7 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-    //// broadcast
-    //if (tpitg == 0) {
-    //    sum[0] /= ne00;
-    //}
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    const float variance = sum[0];
+    const float variance = sum[0] / ne00;
 
     const float scale = 1.0f/sqrt(variance + eps);
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
@@ -262,7 +247,6 @@ kernel void kernel_norm(
     }
 }
 
-
 kernel void kernel_rms_norm(
         device const void * src0,
         device       float * dst,
@@ -630,7 +614,6 @@ kernel void kernel_mul_mat_f16_f32(
             }
         }
     }
-
 }
 
 kernel void kernel_alibi_f32(
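
Note: the kernel_norm hunks above fold the division by ne00 into the reads that follow each parallel reduction, in place of the commented-out single-thread broadcast, and add a barrier before the second pass so every thread sees the reduced sum. For reference, a scalar C model of the per-row math (serial where the kernel reduces across ntg threads; norm_row is a hypothetical helper, not part of ggml):

    #include <math.h>

    static void norm_row(const float * x, float * y, int ne00, float eps) {
        float sum = 0.0f;                   // parallel tree reduction in the kernel
        for (int i = 0; i < ne00; ++i) sum += x[i];
        const float mean = sum / ne00;      // division now happens at the final read

        float sum2 = 0.0f;
        for (int i = 0; i < ne00; ++i) {
            y[i] = x[i] - mean;             // recenter
            sum2 += y[i] * y[i];            // accumulate for the variance
        }
        const float variance = sum2 / ne00;

        const float scale = 1.0f/sqrtf(variance + eps);
        for (int i = 0; i < ne00; ++i) {
            y[i] *= scale;                  // normalize
        }
    }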
@@ -699,25 +682,27 @@ kernel void kernel_rope(
         constant       int & mode,
         constant     float & freq_base,
         constant     float & freq_scale,
-    uint3 tpig[[thread_position_in_grid]]) {
-    const int64_t i3 = tpig[2];
-    const int64_t i2 = tpig[1];
-    const int64_t i1 = tpig[0];
+    uint  tiitg[[thread_index_in_threadgroup]],
+    uint3 tptg[[threads_per_threadgroup]],
+    uint3 tgpig[[threadgroup_position_in_grid]]) {
+    const int64_t i3 = tgpig[2];
+    const int64_t i2 = tgpig[1];
+    const int64_t i1 = tgpig[0];
 
     const bool is_neox = mode & 2;
-    const float theta_scale = pow(freq_base, -2.0f/n_dims);
 
     const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
 
-    float theta = freq_scale * (float)p;
+    const float theta_0 = freq_scale * (float)p;
+    const float inv_ndims = -1.f/n_dims;
 
     if (!is_neox) {
-        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+        for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+
+            const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
             const float cos_theta = cos(theta);
             const float sin_theta = sin(theta);
 
-            theta *= theta_scale;
-
             device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
             device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
@@ -729,12 +714,12 @@ kernel void kernel_rope(
         }
     } else {
         for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
-            for (int64_t ic = 0; ic < n_dims; ic += 2) {
+            for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
+
+                const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
                 const float cos_theta = cos(theta);
                 const float sin_theta = sin(theta);
 
-                theta *= theta_scale;
-
                 const int64_t i0 = ib*n_dims + ic/2;
 
                 device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
 
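Note: the kernel_rope hunks replace the serial recurrence theta *= theta_scale with the closed form theta = theta_0 * pow(freq_base, inv_ndims*i0). Each (i0, i0+1) pair is then independent of the previous iteration, so the loop can be striped across the 32 threads per threadgroup dispatched from ggml-metal.m: thread tiitg handles i0 = 2*tiitg, 2*(tiitg + tptg.x), and so on. A small standalone C check of the equivalence, using hypothetical example values:

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const float freq_base = 10000.0f, freq_scale = 1.0f;
        const int   n_dims = 8, p = 7, ne0 = 8;   // example values only

        // old form: one thread advances theta serially across the row
        float serial_theta = freq_scale * (float)p;
        const float theta_scale = powf(freq_base, -2.0f/n_dims);

        // new form: theta is a direct function of i0, so any thread can start anywhere
        const float theta_0   = freq_scale * (float)p;
        const float inv_ndims = -1.0f/n_dims;

        for (int i0 = 0; i0 < ne0; i0 += 2) {
            const float closed_theta = theta_0 * powf(freq_base, inv_ndims*i0);
            printf("i0=%d  serial=%.6f  closed=%.6f\n", i0, serial_theta, closed_theta);
            serial_theta *= theta_scale;          // one multiply per pair, as before
        }
        return 0;
    }

Both columns print the same values, which is what makes the parallel striding safe.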
ggml.c CHANGED
@@ -106,6 +106,9 @@ typedef void * thread_ret_t;
 #include <sys/stat.h>
 #include <unistd.h>
 
+#endif
+#ifdef GGML_USE_CPU_HBM
+#include <hbwmalloc.h>
 #endif
 
 // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
@@ -195,8 +198,14 @@ typedef void * thread_ret_t;
 #define GGML_ALIGNED_FREE(ptr)    _aligned_free(ptr)
 #else
 inline static void * ggml_aligned_malloc(size_t size) {
+    if (size == 0) {
+        GGML_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n");
+        return NULL;
+    }
     void * aligned_memory = NULL;
-#ifdef GGML_USE_METAL
+#ifdef GGML_USE_CPU_HBM
+    int result = hbw_posix_memalign(&aligned_memory, 16, size);
+#elif GGML_USE_METAL
     int result = posix_memalign(&aligned_memory, sysconf(_SC_PAGESIZE), size);
 #else
     int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
@@ -218,8 +227,12 @@ inline static void * ggml_aligned_malloc(size_t size) {
     return aligned_memory;
 }
 #define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
+#ifdef GGML_USE_CPU_HBM
+#define GGML_ALIGNED_FREE(ptr)    if(NULL != ptr) hbw_free(ptr)
+#else
 #define GGML_ALIGNED_FREE(ptr)    free(ptr)
 #endif
+#endif
 
 #define UNUSED GGML_UNUSED
 #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
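
Note: with GGML_USE_CPU_HBM defined, aligned allocations are routed to high-bandwidth memory through memkind's hbwmalloc interface (hbw_posix_memalign / hbw_free), falling back to posix_memalign otherwise. A standalone sketch of the selection logic above (not ggml code: GGML_PRINT becomes fprintf, the Metal page-aligned branch is omitted, and MEM_ALIGN stands in for GGML_MEM_ALIGN):

    #include <stdio.h>
    #include <stdlib.h>
    #ifdef GGML_USE_CPU_HBM
    #include <hbwmalloc.h>   // memkind's HBM allocator
    #endif

    #define MEM_ALIGN 16

    static void * aligned_malloc_sketch(size_t size) {
        if (size == 0) {
            fprintf(stderr, "WARNING: allocating 0 bytes\n");
            return NULL;
        }
        void * mem = NULL;
    #ifdef GGML_USE_CPU_HBM
        // allocate from the HBM NUMA nodes; alignment of 16 matches the diff
        int result = hbw_posix_memalign(&mem, 16, size);
    #else
        int result = posix_memalign(&mem, MEM_ALIGN, size);
    #endif
        return result == 0 ? mem : NULL;
    }

    static void aligned_free_sketch(void * ptr) {
    #ifdef GGML_USE_CPU_HBM
        if (ptr != NULL) hbw_free(ptr);  // HBM pointers must go back to hbw_free
    #else
        free(ptr);
    #endif
    }

    int main(void) {
        void * p = aligned_malloc_sketch(1024);
        aligned_free_sketch(p);
        return 0;
    }

Pairing the allocator and the free path is the point of the second #ifdef in the diff: a pointer from hbw_posix_memalign cannot be released with plain free.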
@@ -4571,6 +4584,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         return NULL;
     }
 
+    // allow to call ggml_init with 0 size
+    if (params.mem_size == 0) {
+        params.mem_size = GGML_MEM_ALIGN;
+    }
+
     const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN);
 
     *ctx = (struct ggml_context) {
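
Note: the ggml_init hunk makes a zero mem_size legal by padding it up to GGML_MEM_ALIGN, which is convenient for contexts used only for metadata (e.g. with no_alloc). A minimal usage sketch, assuming the ggml.h that ships with this commit:

    #include "ggml.h"

    int main(void) {
        // mem_size == 0 used to be invalid; it is now bumped to GGML_MEM_ALIGN
        struct ggml_init_params params = {
            /*.mem_size   =*/ 0,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true,
        };
        struct ggml_context * ctx = ggml_init(params);
        // ... create tensor metadata without allocating data ...
        ggml_free(ctx);
        return 0;
    }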
@@ -4773,7 +4791,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
 
     size_t obj_alloc_size = 0;
 
-    if (view_src == NULL && ctx->no_alloc == false) {
+    if (view_src == NULL && !ctx->no_alloc) {
         if (ctx->scratch.data != NULL) {
             // allocate tensor data in the scratch buffer
             if (ctx->scratch.offs + data_size > ctx->scratch.size) {
@@ -5474,7 +5492,7 @@ static struct ggml_tensor * ggml_mul_impl(
     }
 
     if (inplace) {
-        GGML_ASSERT(is_node == false);
+        GGML_ASSERT(!is_node);
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
@@ -5517,7 +5535,7 @@ static struct ggml_tensor * ggml_div_impl(
     }
 
     if (inplace) {
-        GGML_ASSERT(is_node == false);
+        GGML_ASSERT(!is_node);
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
@@ -19961,7 +19979,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
 
     struct ggml_tensor * data = NULL;
 
-    if (params.no_alloc == false) {
+    if (!params.no_alloc) {
         data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size);
 
         ok = ok && data != NULL;
@@ -20002,7 +20020,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
     }
 
     // point the data member to the appropriate location in the binary blob using the tensor infos
-    if (params.no_alloc == false) {
+    if (!params.no_alloc) {
         //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
         cur->data = (char *) data->data + ctx->infos[i].offset; // offset from data
     }
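
Note: the five hunks above are the "style" part of the sync: explicit comparisons with false become plain negations (!ctx->no_alloc, !is_node, !params.no_alloc), with no change in behavior.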
 