ggerganov committed on
Commit
c699617
·
1 Parent(s): 12bb60d

metal : use constexpr in FA kernels + fix typedef (llama/12659)

Browse files

* metal : use constexpr in FA kernels

ggml-ci

* cont

ggml-ci

* cont : fix typedef

ggml-ci

Files changed (1) hide show
  1. ggml/src/ggml-metal/ggml-metal.metal +15 -14
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -3128,14 +3128,15 @@ kernel void kernel_flash_attn_ext(
3128
  const int iq2 = tgpig[1];
3129
  const int iq1 = tgpig[0]*Q;
3130
 
3131
- const short DK4 = DK/4;
3132
- const short DK8 = DK/8;
3133
- const short DK16 = DK/16;
3134
- const short DV4 = DV/4;
3135
- const short DV8 = DV/8;
3136
- const short DV16 = DV/16;
3137
- const short NW = N_SIMDWIDTH;
3138
- const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
 
3139
 
3140
  const short TS = nsg*SH; // shared memory size per query in (s_t == float)
3141
  const short T = DK + 2*TS; // shared memory size per query in (half)
@@ -3641,11 +3642,11 @@ kernel void kernel_flash_attn_ext_vec(
3641
  const int iq2 = tgpig[1];
3642
  const int iq1 = tgpig[0];
3643
 
3644
- const short DK4 = DK/4;
3645
- const short DV4 = DV/4;
3646
- const short NW = N_SIMDWIDTH;
3647
- const short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
3648
- const short SH = 2*C; // shared memory per simdgroup
3649
 
3650
  const short T = DK + nsg*SH; // shared memory size per query in (half)
3651
 
@@ -3956,7 +3957,7 @@ kernel void kernel_flash_attn_ext_vec(
3956
  half, half4, \
3957
  half4
3958
 
3959
- typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 128>) flash_attn_ext_vec_t;
3960
 
3961
  template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>;
3962
  #if defined(GGML_METAL_USE_BF16)
 
3128
  const int iq2 = tgpig[1];
3129
  const int iq1 = tgpig[0]*Q;
3130
 
3131
+ constexpr short DK4 = DK/4;
3132
+ constexpr short DK8 = DK/8;
3133
+ constexpr short DK16 = DK/16;
3134
+ constexpr short DV4 = DV/4;
3135
+ constexpr short DV8 = DV/8;
3136
+ constexpr short DV16 = DV/16;
3137
+
3138
+ constexpr short NW = N_SIMDWIDTH;
3139
+ constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
3140
 
3141
  const short TS = nsg*SH; // shared memory size per query in (s_t == float)
3142
  const short T = DK + 2*TS; // shared memory size per query in (half)
 
3642
  const int iq2 = tgpig[1];
3643
  const int iq1 = tgpig[0];
3644
 
3645
+ constexpr short DK4 = DK/4;
3646
+ constexpr short DV4 = DV/4;
3647
+ constexpr short NW = N_SIMDWIDTH;
3648
+ constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
3649
+ constexpr short SH = 2*C; // shared memory per simdgroup
3650
 
3651
  const short T = DK + nsg*SH; // shared memory size per query in (half)
3652
 
 
3957
  half, half4, \
3958
  half4
3959
 
3960
+ typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
3961
 
3962
  template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>;
3963
  #if defined(GGML_METAL_USE_BF16)