ngxson HF Staff committed on
Commit
27b32e6
·
1 Parent(s): d51c0d3

ggml : add mrope kernel for metal (llama/13457)

Browse files
ggml/src/ggml-metal/ggml-metal-impl.h CHANGED
@@ -207,6 +207,10 @@ typedef struct {
207
  float attn_factor;
208
  float beta_fast;
209
  float beta_slow;
 
 
 
 
210
  } ggml_metal_kargs_rope;
211
 
212
  typedef struct {
 
207
  float attn_factor;
208
  float beta_fast;
209
  float beta_slow;
210
+ int32_t sect_0;
211
+ int32_t sect_1;
212
+ int32_t sect_2;
213
+ int32_t sect_3;
214
  } ggml_metal_kargs_rope;
215
 
216
  typedef struct {
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -332,6 +332,10 @@ enum ggml_metal_kernel_type {
332
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
333
  GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
334
  GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
 
 
 
 
335
  GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
336
  GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
337
  GGML_METAL_KERNEL_TYPE_IM2COL_F16,
@@ -1275,6 +1279,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1275
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
1276
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
1277
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
 
 
 
 
1278
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
1279
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
1280
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
@@ -1637,16 +1645,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1637
  case GGML_OP_NORM:
1638
  return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
1639
  case GGML_OP_ROPE:
1640
- {
1641
- const int mode = ((const int32_t *) op->op_params)[2];
1642
- if (mode & GGML_ROPE_TYPE_MROPE) {
1643
- return false;
1644
- }
1645
- if (mode & GGML_ROPE_TYPE_VISION) {
1646
- return false;
1647
- }
1648
- return true;
1649
- }
1650
  case GGML_OP_IM2COL:
1651
  return op->src[0]->type == GGML_TYPE_F16;
1652
  case GGML_OP_POOL_1D:
@@ -3826,6 +3825,7 @@ static bool ggml_metal_encode_node(
3826
  } break;
3827
  case GGML_OP_ROPE:
3828
  {
 
3829
  // make sure we have one or more position id(ne10) per token(ne02)
3830
  GGML_ASSERT(ne10 % ne02 == 0);
3831
  GGML_ASSERT(ne10 >= ne02);
@@ -3852,20 +3852,42 @@ static bool ggml_metal_encode_node(
3852
  memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
3853
  memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
3854
 
3855
- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
 
 
 
 
 
 
 
 
3856
 
3857
  id<MTLComputePipelineState> pipeline = nil;
3858
 
3859
- if (!is_neox) {
3860
  switch (src0->type) {
3861
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
3862
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3863
  default: GGML_ABORT("fatal error");
3864
  };
3865
  } else {
3866
  switch (src0->type) {
3867
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
3868
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
3869
  default: GGML_ABORT("fatal error");
3870
  };
3871
  }
@@ -3896,6 +3918,10 @@ static bool ggml_metal_encode_node(
3896
  /*.attn_factor =*/ attn_factor,
3897
  /*.beta_fast =*/ beta_fast,
3898
  /*.beta_slow =*/ beta_slow,
 
 
 
 
3899
  };
3900
 
3901
  [encoder setComputePipelineState:pipeline];
 
332
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
333
  GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
334
  GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
335
+ GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
336
+ GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
337
+ GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
338
+ GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
339
  GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
340
  GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
341
  GGML_METAL_KERNEL_TYPE_IM2COL_F16,
 
1279
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
1280
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
1281
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
1282
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true);
1283
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true);
1284
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true);
1285
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true);
1286
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
1287
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
1288
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
 
1645
  case GGML_OP_NORM:
1646
  return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
1647
  case GGML_OP_ROPE:
1648
+ return true;
 
 
 
 
 
 
 
 
 
1649
  case GGML_OP_IM2COL:
1650
  return op->src[0]->type == GGML_TYPE_F16;
1651
  case GGML_OP_POOL_1D:
 
3825
  } break;
3826
  case GGML_OP_ROPE:
3827
  {
3828
+
3829
  // make sure we have one or more position id(ne10) per token(ne02)
3830
  GGML_ASSERT(ne10 % ne02 == 0);
3831
  GGML_ASSERT(ne10 >= ne02);
 
3852
  memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
3853
  memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
3854
 
3855
+ const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
3856
+ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
3857
+ const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
3858
+
3859
+ // mrope
3860
+ const int sect_0 = ((const int32_t *) dst->op_params)[11];
3861
+ const int sect_1 = ((const int32_t *) dst->op_params)[12];
3862
+ const int sect_2 = ((const int32_t *) dst->op_params)[13];
3863
+ const int sect_3 = ((const int32_t *) dst->op_params)[14];
3864
 
3865
  id<MTLComputePipelineState> pipeline = nil;
3866
 
3867
+ if (is_neox) {
3868
  switch (src0->type) {
3869
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
3870
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
3871
+ default: GGML_ABORT("fatal error");
3872
+ };
3873
+ } else if (is_mrope && !is_vision) {
3874
+ GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
3875
+ switch (src0->type) {
3876
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
3877
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
3878
+ default: GGML_ABORT("fatal error");
3879
+ };
3880
+ } else if (is_vision) {
3881
+ GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
3882
+ switch (src0->type) {
3883
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
3884
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
3885
  default: GGML_ABORT("fatal error");
3886
  };
3887
  } else {
3888
  switch (src0->type) {
3889
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
3890
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
3891
  default: GGML_ABORT("fatal error");
3892
  };
3893
  }
 
3918
  /*.attn_factor =*/ attn_factor,
3919
  /*.beta_fast =*/ beta_fast,
3920
  /*.beta_slow =*/ beta_slow,
3921
+ /* sect_0 =*/ sect_0,
3922
+ /* sect_1 =*/ sect_1,
3923
+ /* sect_2 =*/ sect_2,
3924
+ /* sect_3 =*/ sect_3,
3925
  };
3926
 
3927
  [encoder setComputePipelineState:pipeline];
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -2713,8 +2713,148 @@ kernel void kernel_rope_neox(
2713
  }
2714
  }
2715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2716
  typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
2717
  typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
 
 
2718
 
2719
  template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
2720
  template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
@@ -2722,6 +2862,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
2722
  template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
2723
  template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
2724
 
 
 
 
 
 
 
2725
  typedef void (im2col_t)(
2726
  device const float * x,
2727
  device char * dst,
 
2713
  }
2714
  }
2715
 
2716
// Multi-section rotary position embedding.
//
// One threadgroup processes the row selected by threadgroup_position_in_grid
// (i1 = x, i2 = y, i3 = z); its threads stride over the row two elements at a time.
//
//   args : rope parameters; sect_0..sect_3 give the width of each section
//   src0 : input rows, addressed in bytes through args.nb00..nb03
//   src1 : int32 positions; up to four values are read per i2, at a stride of args.ne02
//   src2 : float frequency factors indexed by i0/2; ignored when src2 == src0
//   dst  : output rows, addressed in bytes through args.nb0..nb3
//
// For i0 < n_dims the element at i0/2 is rotated together with the element
// n_dims/2 further along (same pairing as kernel_rope_neox); elements at
// i0 >= n_dims are copied through unchanged.
template<typename T>
kernel void kernel_rope_multi(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int i3 = tgpig[2];
    const int i2 = tgpig[1];
    const int i1 = tgpig[0];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * pos = (device const int32_t *) src1;

    const float inv_ndims = -1.f/args.n_dims;

    // cumulative section boundaries (identical for every element of the row)
    const int end_s0     = args.sect_0;
    const int end_s1     = args.sect_0 + args.sect_1;
    const int end_s2     = args.sect_0 + args.sect_1 + args.sect_2;
    const int sect_total = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;

    // distance (in elements) between the two halves of a rotated pair
    const int half_dims = args.n_dims/2;

    // byte address of the first element of this row in src0 / dst
    device const char * row_in  = src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01;
    device       char * row_out = dst  + i3*args.nb3  + i2*args.nb2  + i1*args.nb1;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        if (i0 >= args.n_dims) {
            // outside the rotated range: plain copy of the element pair
            device const T * const in  = (device T *)(row_in  + i0*args.nb00);
            device       T *       out = (device T *)(row_out + i0*args.nb0);

            out[0] = in[0];
            out[1] = in[1];

            continue;
        }

        const int ic = i0/2;

        // pick which of the four position streams drives this element
        const int sector = ic % sect_total;

        int pos_row = 3;
        if (sector < end_s0) {
            pos_row = 0;
        } else if (sector < end_s1) {
            pos_row = 1;
        } else if (sector < end_s2) {
            pos_row = 2;
        }

        const float theta_base = (float) pos[i2 + args.ne02 * pos_row];

        const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);

        const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;

        float cos_theta;
        float sin_theta;

        rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

        device const T * const in  = (device T *)(row_in  + ic*args.nb00);
        device       T *       out = (device T *)(row_out + ic*args.nb0);

        const float x0 = in[0];
        const float x1 = in[half_dims];

        out[0]         = x0*cos_theta - x1*sin_theta;
        out[half_dims] = x0*sin_theta + x1*cos_theta;
    }
}
2786
+
2787
// Two-section ("vision") rotary position embedding.
//
// One threadgroup processes the row selected by threadgroup_position_in_grid
// (i1 = x, i2 = y, i3 = z); its threads stride over the row two elements at a time.
//
//   args : rope parameters; only sect_0 and sect_1 are used here
//   src0 : input rows, addressed in bytes through args.nb00..nb03
//   src1 : int32 positions; two values are read per i2, at a stride of args.ne02
//   src2 : float frequency factors indexed by i0/2; ignored when src2 == src0
//   dst  : output rows, addressed in bytes through args.nb0..nb3
//
// Differences from kernel_rope_multi:
//   - the rotated range is i0 < 2*n_dims and a pair is (ic, ic + n_dims)
//   - the frequency exponent restarts at 0 inside each section (p below)
//     instead of running continuously over i0
// Elements at i0 >= 2*n_dims are copied through unchanged.
template<typename T>
kernel void kernel_rope_vision(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int i3 = tgpig[2];
    const int i2 = tgpig[1];
    const int i1 = tgpig[0];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * pos = (device const int32_t *) src1;

    const float inv_ndims = -1.f/args.n_dims;

    // mrope theta calculations (only support 2 dimensions)
    const int sect_dims = args.sect_0 + args.sect_1;

    float cos_theta;
    float sin_theta;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
            const int ic = i0/2;

            const int sector = ic % sect_dims;

            // The first section spans [0, sect_0) and the second [sect_0, sect_0 + sect_1).
            // The boundary has to be sect_0 so that it agrees with the `sector - sect_0`
            // rebasing below: testing against sect_1 gives a negative p (sect_1 < sect_0)
            // or applies the first position stream to second-section sectors
            // (sect_1 > sect_0) whenever the two section sizes differ.
            float p;
            float theta_base;
            if (sector < args.sect_0) {
                p = (float) sector;
                theta_base = (float) pos[i2];
            } else {
                p = (float) sector - args.sect_0;
                theta_base = (float) pos[i2 + args.ne02];
            }

            const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
            // end of mrope

            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;

            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);

            const float x0 = src[0];
            const float x1 = src[args.n_dims]; // different from kernel_rope_multi

            dst_data[0]           = x0*cos_theta - x1*sin_theta;
            dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
        } else {
            // outside the rotated range: plain copy of the element pair
            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);

            dst_data[0] = src[0];
            dst_data[1] = src[1];
        }
    }
}
2853
+
2854
  typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
2855
  typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
2856
+ typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
2857
+ typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
2858
 
2859
  template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
2860
  template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
 
2862
  template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
2863
  template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
2864
 
2865
+ template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
2866
+ template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
2867
+
2868
+ template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
2869
+ template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
2870
+
2871
  typedef void (im2col_t)(
2872
  device const float * x,
2873
  device char * dst,