Spaces:
Running
Running
metal : use constexpr in FA kernels + fix typedef (llama/12659)
Browse files* metal : use constexpr in FA kernels
ggml-ci
* cont
ggml-ci
* cont : fix typedef
ggml-ci
ggml/src/ggml-metal/ggml-metal.metal
CHANGED
|
@@ -3128,14 +3128,15 @@ kernel void kernel_flash_attn_ext(
|
|
| 3128 |
const int iq2 = tgpig[1];
|
| 3129 |
const int iq1 = tgpig[0]*Q;
|
| 3130 |
|
| 3131 |
-
|
| 3132 |
-
|
| 3133 |
-
|
| 3134 |
-
|
| 3135 |
-
|
| 3136 |
-
|
| 3137 |
-
|
| 3138 |
-
|
|
|
|
| 3139 |
|
| 3140 |
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
|
| 3141 |
const short T = DK + 2*TS; // shared memory size per query in (half)
|
|
@@ -3641,11 +3642,11 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3641 |
const int iq2 = tgpig[1];
|
| 3642 |
const int iq1 = tgpig[0];
|
| 3643 |
|
| 3644 |
-
|
| 3645 |
-
|
| 3646 |
-
|
| 3647 |
-
|
| 3648 |
-
|
| 3649 |
|
| 3650 |
const short T = DK + nsg*SH; // shared memory size per query in (half)
|
| 3651 |
|
|
@@ -3956,7 +3957,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3956 |
half, half4, \
|
| 3957 |
half4
|
| 3958 |
|
| 3959 |
-
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128,
|
| 3960 |
|
| 3961 |
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>;
|
| 3962 |
#if defined(GGML_METAL_USE_BF16)
|
|
|
|
| 3128 |
const int iq2 = tgpig[1];
|
| 3129 |
const int iq1 = tgpig[0]*Q;
|
| 3130 |
|
| 3131 |
+
constexpr short DK4 = DK/4;
|
| 3132 |
+
constexpr short DK8 = DK/8;
|
| 3133 |
+
constexpr short DK16 = DK/16;
|
| 3134 |
+
constexpr short DV4 = DV/4;
|
| 3135 |
+
constexpr short DV8 = DV/8;
|
| 3136 |
+
constexpr short DV16 = DV/16;
|
| 3137 |
+
|
| 3138 |
+
constexpr short NW = N_SIMDWIDTH;
|
| 3139 |
+
constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
|
| 3140 |
|
| 3141 |
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
|
| 3142 |
const short T = DK + 2*TS; // shared memory size per query in (half)
|
|
|
|
| 3642 |
const int iq2 = tgpig[1];
|
| 3643 |
const int iq1 = tgpig[0];
|
| 3644 |
|
| 3645 |
+
constexpr short DK4 = DK/4;
|
| 3646 |
+
constexpr short DV4 = DV/4;
|
| 3647 |
+
constexpr short NW = N_SIMDWIDTH;
|
| 3648 |
+
constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
|
| 3649 |
+
constexpr short SH = 2*C; // shared memory per simdgroup
|
| 3650 |
|
| 3651 |
const short T = DK + nsg*SH; // shared memory size per query in (half)
|
| 3652 |
|
|
|
|
| 3957 |
half, half4, \
|
| 3958 |
half4
|
| 3959 |
|
| 3960 |
+
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
|
| 3961 |
|
| 3962 |
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>;
|
| 3963 |
#if defined(GGML_METAL_USE_BF16)
|