ggerganov committed
Commit a49f5c2 · 1 Parent(s): a00d66c

metal : use F32 prec in FA kernels (llama/12688)

* metal : use F32 prec in FA kernels

ggml-ci

* cont : fix FA vec kernel

ggml-ci
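
For context: both FA kernels keep a running online-softmax state (S, M) per query and rescale the accumulated sum by exp(m - M) whenever a new block raises the max. In f16 that state carries about 3 decimal digits with a maximum of 65504, which is what this commit widens to f32. A minimal standalone sketch of the recurrence (not from this commit; plain C, assuming a compiler with _Float16 support such as recent Clang, and illustrative logit values):

// Sketch only: the per-step online-softmax update used by the FA kernels,
// run in f32 vs f16. Build: clang -O2 demo.c -lm
#include <math.h>
#include <stdio.h>

// One step: fold logit s_in into the running sum S and running max M.
#define STEP(T, S, M, s_in) do {                                    \
    const T s_ = (T)(s_in);                                         \
    const T m_ = (M);                                               \
    (M) = (M) > s_ ? (M) : s_;                                      \
    const T ms_ = (T)expf((float)(m_ - (M))); /* rescale old sum */ \
    const T vs_ = (T)expf((float)(s_ - (M))); /* new term        */ \
    (S) = (S)*ms_ + vs_;                                            \
} while (0)

int main(void) {
    // illustrative logits with a wide dynamic range, as scaled QK^T produces
    const float logits[] = { -20.0f, 11.0f, 11.5f, -30.0f, 12.0f };

    float    Sf = 0.0f, Mf = -1e30f;
    _Float16 Sh = 0.0f, Mh = (_Float16)(-65504.0f/2); // mirrors -__FLT16_MAX__/2

    for (int i = 0; i < 5; ++i) {
        STEP(float,    Sf, Mf, logits[i]);
        STEP(_Float16, Sh, Mh, logits[i]);
    }

    // the f16 sum is rounded at every step; the gap grows with sequence length
    printf("f32: S = %.7f\nf16: S = %.7f\n", Sf, (float)Sh);
    return 0;
}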

ggml/src/ggml-metal/ggml-metal.m CHANGED

@@ -4179,7 +4179,7 @@ static void ggml_metal_encode_node(
                 // ne00*(nsg)
                 // each simdgroup has a full f16 head vector in shared mem to accumulate results
                 //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 2*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
 
                 int64_t nsgmax = 2;
                 while (true) {
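
The 2*ncpsg -> 4*ncpsg bump pays for the now-f32 softmax/mask scratch: the expression counts half-sized (2-byte) elements via the sizeof(float)/2 factor, so each f32 slot takes two of them. A quick sanity check of the macro, in plain C, with GGML_PAD as defined in ggml.h and assumed representative sizes (nqptg = 8, ncpsg = 32, 128-dim K/V heads):

// Sketch only: evaluate FATTN_SMEM before/after the change to see the extra
// threadgroup memory per kernel launch. All sizes are assumptions.
#include <stdio.h>

#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main(void) {
    const int nqptg = 8, ncpsg = 32, ne00 = 128, ne20 = 128;

    for (int nsg = 1; nsg <= 4; nsg *= 2) {
        const int smem_old = GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 2*ncpsg*nsg) + ne20*nsg)*(int)(sizeof(float)/2), 16);
        const int smem_new = GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*nsg) + ne20*nsg)*(int)(sizeof(float)/2), 16);
        printf("nsg = %d: %5d B -> %5d B (+%d B)\n", nsg, smem_old, smem_new, smem_new - smem_old);
    }
    return 0;
}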
ggml/src/ggml-metal/ggml-metal.metal CHANGED

@@ -3184,8 +3184,8 @@ kernel void kernel_flash_attn_ext(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     {
-        half S[Q] = { [0 ... Q-1] = 0.0f };
-        half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
+        float S[Q] = { [0 ... Q-1] = 0.0f };
+        float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
 
         // thread indices inside the simdgroup
         // TODO: see if we can utilize quad-group functions for better performance
@@ -3202,13 +3202,13 @@
 
         const bool has_mask = mask != q;
 
-        half slope = 1.0f;
+        float slope = 1.0f;
 
         // ALiBi
         if (args.max_bias > 0.0f) {
             const short h = iq2;
 
-            const half  base = h < args.n_head_log2 ? args.m0 : args.m1;
+            const float base = h < args.n_head_log2 ? args.m0 : args.m1;
             const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
             slope = pow(base, exph);
@@ -3224,14 +3224,14 @@
 
         if (has_mask) {
             // used to detect blocks full of -INF
-            half smax = -INFINITY;
+            float smax = -INFINITY;
 
             // load the mask in shared memory
             #pragma unroll(Q)
             for (short j = 0; j < Q; ++j) {
                 device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
 
-                const half m = pm[ic + tiisg];
+                const float m = pm[ic + tiisg];
 
                 ss[j*TS + C + tiisg] = m;
                 smax = max(smax, m);
@@ -3327,10 +3327,10 @@
             // online softmax
             {
                 for (ushort j = 0; j < Q; ++j) {
-                    const half m = M[j];
+                    const float m = M[j];
 
                     // scale and apply the logitcap / mask
-                    half s = ss[j*TS + tiisg]*args.scale;
+                    float s = ss[j*TS + tiisg]*args.scale;
 
                     if (args.logit_softcap != 0.0f) {
                         s = args.logit_softcap*precise::tanh(s);
@@ -3341,8 +3341,8 @@
 
                     M[j] = simd_max(max(M[j], s));
 
-                    const half ms = exp(m - M[j]);
-                    const half vs = exp(s - M[j]);
+                    const float ms = exp(m - M[j]);
+                    const float vs = exp(s - M[j]);
 
                     S[j] = S[j]*ms + simd_sum(vs);
 
@@ -3444,8 +3444,8 @@
 
         // reduce the warps sequentially
         for (ushort sg = 1; sg < nsg; ++sg) {
-            half S = { 0.0f };
-            half M = { -__FLT16_MAX__/2 };
+            float S = { 0.0f };
+            float M = { -__FLT16_MAX__/2 };
 
             threadgroup_barrier(mem_flags::mem_threadgroup);
 
@@ -3461,16 +3461,16 @@
         // the first simdgroup accumulates the results from the other simdgroups
         if (sgitg == 0) {
             for (short j = 0; j < Q; ++j) {
-                const half S0 = ss[j*TS +         0];
-                const half S1 = ss[j*TS + sg*SH + 0];
+                const float S0 = ss[j*TS +         0];
+                const float S1 = ss[j*TS + sg*SH + 0];
 
-                const half M0 = ss[j*TS +         1];
-                const half M1 = ss[j*TS + sg*SH + 1];
+                const float M0 = ss[j*TS +         1];
+                const float M1 = ss[j*TS + sg*SH + 1];
 
                 M = max(M0, M1);
 
-                const half ms0 = exp(M0 - M);
-                const half ms1 = exp(M1 - M);
+                const float ms0 = exp(M0 - M);
+                const float ms1 = exp(M1 - M);
 
                 S = S0*ms0 + S1*ms1;
 
@@ -3646,16 +3646,16 @@ kernel void kernel_flash_attn_ext_vec(
     constexpr short DV4 = DV/4;
     constexpr short NW  = N_SIMDWIDTH;
     constexpr short NL  = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
-    constexpr short SH  = 2*C;   // shared memory per simdgroup
+    constexpr short SH  = 4*C;   // shared memory per simdgroup
 
     const short T = DK + nsg*SH; // shared memory size per query in (half)
 
-  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                0*DK); // holds the query data
-    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                0*DK); // same as above but in q4_t
-    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 + sgitg*SH     + Q*DK); // scratch buffer for attention
-    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 + sgitg*SH     + Q*DK); // same as above but in s4_t
-    threadgroup half  * sm  = (threadgroup half  *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask
-    threadgroup o4_t  * sr4 = (threadgroup o4_t  *) (shmem_f16 + sgitg*DV     + Q*T);  // scratch buffer for the results
+  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                  0*DK); // holds the query data
+    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                  0*DK); // same as above but in q4_t
+    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 + sgitg*SH       + Q*DK); // scratch buffer for attention
+    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 + sgitg*SH       + Q*DK); // same as above but in s4_t
+    threadgroup float * sm  = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
+    threadgroup o4_t  * sr4 = (threadgroup o4_t  *) (shmem_f16 + sgitg*DV       + Q*T);  // scratch buffer for the results
 
     // store the result for all queries in local memory (the O matrix from the paper)
     o4_t lo[DV4/NL];
@@ -3684,8 +3684,8 @@
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     {
-        half S = 0.0f;
-        half M = -__FLT16_MAX__/2;
+        float S = 0.0f;
+        float M = -__FLT16_MAX__/2;
 
         // thread indices inside the simdgroup
         const short tx = tiisg%NL;
@@ -3703,13 +3703,13 @@
         // pointer to the mask
         device const half * pm = (device const half *) (mask + iq1*args.nb31);
 
-        half slope = 1.0f;
+        float slope = 1.0f;
 
         // ALiBi
        if (args.max_bias > 0.0f) {
            const short h = iq2;
 
-            const half  base = h < args.n_head_log2 ? args.m0 : args.m1;
+            const float base = h < args.n_head_log2 ? args.m0 : args.m1;
            const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
            slope = pow(base, exph);
@@ -3799,13 +3799,13 @@
 
             // online softmax
             {
-                const half m = M;
-                const half s = ss[tiisg];
+                const float m = M;
+                const float s = ss[tiisg];
 
                 M = simd_max(max(M, s));
 
-                const half ms = exp(m - M);
-                const half vs = exp(s - M);
+                const float ms = exp(m - M);
+                const float vs = exp(s - M);
 
                 S = S*ms + simd_sum(vs);
 
@@ -3836,7 +3836,7 @@
                         v4_t mv;
                         deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
 
-                        lo[ii/NL] += mv*ms;
+                        lo[ii/NL] += o4_t(float4(mv)*float4(ms));
                     }
                 }
             }
@@ -3907,18 +3907,18 @@
     // parallel reduce
     for (short r = nsg/2; r > 0; r >>= 1) {
         if (sgitg < r) {
-            const half S0 = ss[       0];
-            const half S1 = ss[r*SH + 0];
+            const float S0 = ss[           0];
+            const float S1 = ss[r*(SH/2) + 0];
 
-            const half M0 = ss[       1];
-            const half M1 = ss[r*SH + 1];
+            const float M0 = ss[           1];
+            const float M1 = ss[r*(SH/2) + 1];
 
-            const half M = max(M0, M1);
+            const float M = max(M0, M1);
 
-            const half ms0 = exp(M0 - M);
-            const half ms1 = exp(M1 - M);
+            const float ms0 = exp(M0 - M);
+            const float ms1 = exp(M1 - M);
 
-            const half S = S0*ms0 + S1*ms1;
+            const float S = S0*ms0 + S1*ms1;
 
             if (tiisg == 0) {
                 ss[0] = S;
@@ -3950,11 +3950,11 @@
 // in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
 //
 #define FA_TYPES \
-           half4, \
-           half4, \
-           half4, \
-           float, \
-    half,  half4, \
+           half4, \
+           half4, \
+           half4, \
+           float, \
+    float, float4, \
            half4
 
 typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
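
In the vec kernel the per-simdgroup scratch doubles (SH: 2*C -> 4*C) because both the C score slots and the C mask slots are f32 now, and each f32 occupies two half units of shmem_f16; that is also why sm moves from half-offset C to 2*C. A standalone sketch of the pointer carving, in plain C, with assumed illustrative sizes (C = 32, DK = 128, Q = 1, 2 simdgroups), not the kernel's actual template arguments:

// Sketch only: carving float-typed score/mask scratch out of a half-typed
// shared buffer, mirroring the vec kernel's layout after this commit.
#include <stdint.h>
#include <stdio.h>

typedef uint16_t half_t; // stand-in for Metal's 2-byte half

enum { C = 32, DK = 128, Q = 1, NSG = 2, SH = 4*C };

int main(void) {
    static half_t shmem_f16[Q*DK + NSG*SH]; // query data + per-simdgroup scratch

    for (int sgitg = 0; sgitg < NSG; ++sgitg) {
        // C f32 scores occupy the first 2*C half units of the scratch...
        float *ss = (float *)(shmem_f16 + sgitg*SH       + Q*DK);
        // ...and the C f32 mask values start right after, hence the 2*C
        float *sm = (float *)(shmem_f16 + sgitg*SH + 2*C + Q*DK);

        printf("sg %d: ss at byte %td, sm at byte %td\n", sgitg,
               (char *)ss - (char *)shmem_f16, (char *)sm - (char *)shmem_f16);
    }
    return 0;
}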
 
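The same half-vs-float unit mismatch explains r*SH -> r*(SH/2) in the parallel reduce: SH is counted in half units, but ss indexes the scratch as float, so simdgroup r's (S, M) pair sits at float index r*SH/2. Quick arithmetic in plain C, with the same assumed C = 32:

// Sketch only: the offset of simdgroup r's scratch expressed in half units
// (r*SH) and in float units (r*(SH/2)) names the same byte location.
#include <stdio.h>

int main(void) {
    const int C = 32, SH = 4*C; // assumed block width; SH in half units

    for (int r = 1; r <= 4; r *= 2) {
        printf("r = %d: %4d halves = %4d floats = %4d bytes\n",
               r, r*SH, r*(SH/2), r*SH*2);
    }
    return 0;
}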