metal : use F32 prec in FA kernels (llama/12688)
* metal : use F32 prec in FA kernels
ggml-ci
* cont : fix FA vec kernel
ggml-ci
ggml/src/ggml-metal/ggml-metal.m
CHANGED
@@ -4179,7 +4179,7 @@ static void ggml_metal_encode_node(
                 // ne00*(nsg)
                 // each simdgroup has a full f16 head vector in shared mem to accumulate results
                 //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) +
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))

                 int64_t nsgmax = 2;
                 while (true) {
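FATTN_SMEM sizes the threadgroup memory for the FA kernel in bytes: the element tally is kept in half-sized units (the sizeof(float)/2 factor equals 2, i.e. sizeof(half)) and the result is padded to 16. A standalone C++ sketch of the arithmetic; the head-size and tile values below are illustrative assumptions, and pad() mirrors ggml's GGML_PAD round-up:

    #include <cstdio>

    // round x up to a multiple of n (mirrors ggml's GGML_PAD)
    constexpr size_t pad(size_t x, size_t n) { return (x + n - 1) & ~(n - 1); }

    int main() {
        // illustrative parameters, not taken from the commit:
        const size_t ne00  = 128; // K head size
        const size_t ne20  = 128; // V head size
        const size_t nqptg = 8;   // queries processed per threadgroup
        const size_t ncpsg = 32;  // cache values processed per simdgroup

        for (size_t nsg = 1; nsg <= 4; nsg *= 2) {
            // same expression as FATTN_SMEM: counts in half-sized (2-byte) units
            const size_t smem = pad((nqptg*(pad(ne00, 128) + 4*ncpsg*nsg) + ne20*nsg)*(sizeof(float)/2), 16);
            printf("nsg = %zu -> %zu bytes of threadgroup memory\n", nsg, smem);
        }
    }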
ggml/src/ggml-metal/ggml-metal.metal
CHANGED
@@ -3184,8 +3184,8 @@ kernel void kernel_flash_attn_ext(
     threadgroup_barrier(mem_flags::mem_threadgroup);

     {
-
-
+        float S[Q] = { [0 ... Q-1] = 0.0f };
+        float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };

         // thread indices inside the simdgroup
         // TODO: see if we can utilize quad-group functions for better performance
@@ -3202,13 +3202,13 @@ kernel void kernel_flash_attn_ext(

     const bool has_mask = mask != q;

-
+    float slope = 1.0f;

     // ALiBi
     if (args.max_bias > 0.0f) {
         const short h = iq2;

-        const
+        const float base = h < args.n_head_log2 ? args.m0 : args.m1;
         const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;

         slope = pow(base, exph);
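This hunk promotes the ALiBi slope computation to float: slope = base^exph, with base chosen per head from args.m0/args.m1. A C++ sketch of the resulting slope schedule; the m0/m1 setup shown follows ggml's usual host-side ALiBi initialization and is an assumption here, not part of this diff:

    #include <cmath>
    #include <cstdio>

    int main() {
        // illustrative model shape; max_bias is the ALiBi strength (args.max_bias)
        const int   n_head   = 12;
        const float max_bias = 8.0f;

        // n_head rounded down to a power of two, as in ggml
        const int n_head_log2 = 1 << (int) floorf(log2f((float) n_head));

        // assumed host-side setup for args.m0/args.m1 (mirrors ggml's CPU path)
        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

        for (int h = 0; h < n_head; ++h) {
            // the same per-head selection the kernel performs
            const float base = h < n_head_log2 ? m0 : m1;
            const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

            printf("head %2d: slope = %.6f\n", h, powf(base, (float) exph));
        }
    }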
@@ -3224,14 +3224,14 @@ kernel void kernel_flash_attn_ext(

         if (has_mask) {
             // used to detect blocks full of -INF
-
+            float smax = -INFINITY;

             // load the mask in shared memory
 #pragma unroll(Q)
             for (short j = 0; j < Q; ++j) {
                 device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);

-                const
+                const float m = pm[ic + tiisg];

                 ss[j*TS + C + tiisg] = m;
                 smax = max(smax, m);
@@ -3327,10 +3327,10 @@ kernel void kernel_flash_attn_ext(
             // online softmax
             {
                 for (ushort j = 0; j < Q; ++j) {
-                    const
+                    const float m = M[j];

                     // scale and apply the logitcap / mask
-
+                    float s = ss[j*TS + tiisg]*args.scale;

                     if (args.logit_softcap != 0.0f) {
                         s = args.logit_softcap*precise::tanh(s);
@@ -3341,8 +3341,8 @@ kernel void kernel_flash_attn_ext(

                     M[j] = simd_max(max(M[j], s));

-                    const
-                    const
+                    const float ms = exp(m - M[j]);
+                    const float vs = exp(s - M[j]);

                     S[j] = S[j]*ms + simd_sum(vs);

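These two hunks are the core of the change: the online-softmax state for each query row is a running maximum M[j] and a running denominator S[j], and every new block of scores rescales the accumulated sum by exp(m - M[j]). These exponentials are exactly where f16 loses range and precision, which is what moving the state to F32 addresses. A single-lane C++ sketch of the recurrence, with the cross-lane simd_max/simd_sum reductions collapsed to scalar operations:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        // scores for one query row, arriving block by block; -INFINITY = masked
        const std::vector<float> scores = { 1.5f, -0.25f, 3.0f, 2.0f, -INFINITY };

        float S = 0.0f;       // running denominator
        float M = -HUGE_VALF; // running maximum (the kernel starts at -__FLT16_MAX__/2)

        for (const float s : scores) {
            const float m = M;            // previous maximum
            M = fmaxf(M, s);              // kernel: M[j] = simd_max(max(M[j], s))

            const float ms = expf(m - M); // rescales everything summed so far
            const float vs = expf(s - M); // contribution of the new score

            S = S*ms + vs;                // kernel: S[j] = S[j]*ms + simd_sum(vs)
        }

        printf("S = %f, M = %f\n", S, M);
    }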
@@ -3444,8 +3444,8 @@ kernel void kernel_flash_attn_ext(

         // reduce the warps sequentially
         for (ushort sg = 1; sg < nsg; ++sg) {
-
-
+            float S = { 0.0f };
+            float M = { -__FLT16_MAX__/2 };

             threadgroup_barrier(mem_flags::mem_threadgroup);

@@ -3461,16 +3461,16 @@ kernel void kernel_flash_attn_ext(
             // the first simdgroup accumulates the results from the other simdgroups
             if (sgitg == 0) {
                 for (short j = 0; j < Q; ++j) {
-                    const
-                    const
+                    const float S0 = ss[j*TS +         0];
+                    const float S1 = ss[j*TS + sg*SH + 0];

-                    const
-                    const
+                    const float M0 = ss[j*TS +         1];
+                    const float M1 = ss[j*TS + sg*SH + 1];

                     M = max(M0, M1);

-                    const
-                    const
+                    const float ms0 = exp(M0 - M);
+                    const float ms1 = exp(M1 - M);

                     S = S0*ms0 + S1*ms1;

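After the per-simdgroup passes, each simdgroup holds its own (S, M) pair per query row. Two partials over disjoint key ranges merge exactly: take the larger maximum and rescale each sum into it, which is the algebra of the accumulation above. A C++ sketch with a single-pass reference check:

    #include <cmath>
    #include <cstdio>

    // a softmax partial over one key range: running sum and running max
    struct Partial { float S; float M; };

    Partial merge(Partial a, Partial b) {
        const float M   = fmaxf(a.M, b.M);
        const float ms0 = expf(a.M - M); // rescale of the first sum
        const float ms1 = expf(b.M - M); // rescale of the second sum
        return { a.S*ms0 + b.S*ms1, M };
    }

    int main() {
        // partial 0 covers scores {1, 2}; partial 1 covers {3, 0.5}
        Partial p0 = { expf(1.0f - 2.0f) + expf(2.0f - 2.0f), 2.0f };
        Partial p1 = { expf(3.0f - 3.0f) + expf(0.5f - 3.0f), 3.0f };

        const Partial p = merge(p0, p1);

        // reference: a single pass over all four scores
        const float M   = 3.0f;
        const float ref = expf(1 - M) + expf(2 - M) + expf(3 - M) + expf(0.5f - M);

        printf("merged S = %f (reference %f), M = %f\n", p.S, ref, p.M);
    }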
@@ -3646,16 +3646,16 @@ kernel void kernel_flash_attn_ext_vec(
     constexpr short DV4 = DV/4;
     constexpr short NW  = N_SIMDWIDTH;
     constexpr short NL  = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
-    constexpr short SH =
+    constexpr short SH  = 4*C; // shared memory per simdgroup

     const short T = DK + nsg*SH; // shared memory size per query in (half)

-  //threadgroup q_t
-    threadgroup q4_t
-    threadgroup s_t
-    threadgroup s4_t
-    threadgroup
-    threadgroup o4_t
+  //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
+    threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
+    threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
+    threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
+    threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results

     // store the result for all queries in local memory (the O matrix from the paper)
     o4_t lo[DV4/NL];
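The vec kernel carves a single threadgroup buffer, addressed in half-sized units, into the query block (Q*DK), per-simdgroup score scratch (now SH = 4*C, with the mask viewed as float starting 2*C in, i.e. C floats), and per-simdgroup result slots. A C++ sketch that prints the offsets; the DK/DV/Q/C/nsg values are illustrative assumptions, not from the diff:

    #include <cstdio>

    int main() {
        // illustrative values for the kernel's compile-time parameters:
        const int DK  = 128; // K head size
        const int DV  = 128; // V head size
        const int Q   = 1;   // queries per threadgroup
        const int C   = 32;  // cache items per simdgroup
        const int nsg = 4;   // simdgroups per threadgroup

        const int SH = 4*C;          // scratch per simdgroup (in half units)
        const int T  = DK + nsg*SH;  // shared memory per query (in half units)

        for (int sgitg = 0; sgitg < nsg; ++sgitg) {
            printf("sg %d: sq4 @ %4d | ss @ %4d | sm @ %4d | sr4 @ %4d\n",
                   sgitg,
                   0,                     // query data (shared by all simdgroups)
                   sgitg*SH + Q*DK,       // attention score scratch (ss/ss4)
                   sgitg*SH + 2*C + Q*DK, // mask scratch, viewed as float (sm)
                   sgitg*DV + Q*T);       // per-simdgroup partial results (sr4)
        }
    }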
@@ -3684,8 +3684,8 @@ kernel void kernel_flash_attn_ext_vec(
     threadgroup_barrier(mem_flags::mem_threadgroup);

     {
-
-
+        float S = 0.0f;
+        float M = -__FLT16_MAX__/2;

         // thread indices inside the simdgroup
         const short tx = tiisg%NL;
@@ -3703,13 +3703,13 @@ kernel void kernel_flash_attn_ext_vec(
         // pointer to the mask
         device const half * pm = (device const half *) (mask + iq1*args.nb31);

-
+        float slope = 1.0f;

         // ALiBi
         if (args.max_bias > 0.0f) {
             const short h = iq2;

-            const
+            const float base = h < args.n_head_log2 ? args.m0 : args.m1;
             const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;

             slope = pow(base, exph);
@@ -3799,13 +3799,13 @@ kernel void kernel_flash_attn_ext_vec(

         // online softmax
         {
-            const
-            const
+            const float m = M;
+            const float s = ss[tiisg];

             M = simd_max(max(M, s));

-            const
-            const
+            const float ms = exp(m - M);
+            const float vs = exp(s - M);

             S = S*ms + simd_sum(vs);

@@ -3836,7 +3836,7 @@ kernel void kernel_flash_attn_ext_vec(
                 v4_t mv;
                 deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);

-                lo[ii/NL] += mv*ms;
+                lo[ii/NL] += o4_t(float4(mv)*float4(ms));
             }
         }
     }
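The accumulation into lo now forms the product in float4 and converts once, so the float rescale factor ms enters at full range instead of being squeezed through f16, where exp() flushes to zero once a score sits roughly 17 below the running maximum. A small C++ illustration of that range cliff; _Float16 requires a compiler with that extension (recent clang or gcc):

    #include <cmath>
    #include <cstdio>

    int main() {
        // rescale factors for scores this far below the running maximum:
        const float d[] = { -5.0f, -12.0f, -18.0f, -30.0f };

        for (const float x : d) {
            const float    f = expf(x);      // what the kernel now computes in f32
            const _Float16 h = (_Float16) f; // what an f16 pipeline would keep

            printf("exp(%5.1f) = %.3e as float, %.3e as half\n", x, f, (float) h);
        }
    }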
@@ -3907,18 +3907,18 @@ kernel void kernel_flash_attn_ext_vec(
     // parallel reduce
     for (short r = nsg/2; r > 0; r >>= 1) {
         if (sgitg < r) {
-            const
-            const
+            const float S0 = ss[           0];
+            const float S1 = ss[r*(SH/2) + 0];

-            const
-            const
+            const float M0 = ss[           1];
+            const float M1 = ss[r*(SH/2) + 1];

-            const
+            const float M = max(M0, M1);

-            const
-            const
+            const float ms0 = exp(M0 - M);
+            const float ms1 = exp(M1 - M);

-            const
+            const float S = S0*ms0 + S1*ms1;

             if (tiisg == 0) {
                 ss[0] = S;
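Unlike the sequential warp reduction in the non-vec kernel above, the vec kernel folds its simdgroups as a binary tree: each round halves the number of active simdgroups, so nsg partials take log2(nsg) rounds, and r*(SH/2) is the stride to the partner simdgroup's scratch. A C++ sketch of the same fold over illustrative (S, M) partials:

    #include <cmath>
    #include <cstdio>

    int main() {
        const int nsg = 4; // illustrative simdgroup count (power of two)

        float S[nsg] = { 1.2f, 0.7f, 2.1f, 0.4f }; // per-simdgroup sums
        float M[nsg] = { 2.0f, 3.5f, 1.0f, 2.5f }; // per-simdgroup maxima

        for (int r = nsg/2; r > 0; r >>= 1) {
            // in the kernel, all simdgroups with sgitg < r run this concurrently
            for (int sg = 0; sg < r; ++sg) {
                const float Mn = fmaxf(M[sg], M[sg + r]);

                S[sg] = S[sg]*expf(M[sg] - Mn) + S[sg + r]*expf(M[sg + r] - Mn);
                M[sg] = Mn;
            }
        }

        printf("reduced: S = %f, M = %f\n", S[0], M[0]);
    }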
@@ -3950,11 +3950,11 @@ kernel void kernel_flash_attn_ext_vec(
 // in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
 //
 #define FA_TYPES \
-    half4,
-    half4,
-    half4,
-    float,
-
+    half4,         \
+    half4,         \
+    half4,         \
+    float,         \
+    float, float4, \
     half4

 typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
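FA_TYPES is where the precision mix is declared: the kernels are templates over their storage types, and the visible change routes float/float4 into the softmax-scratch slots (s_t/s4_t, matching the threadgroup pointers earlier) while q/k/v and the output stay half-based. A reduced C++ sketch of the same pattern, with fewer type slots than the real kernel:

    #include <cstdio>

    struct half4  { unsigned short v[4]; }; // CPU stand-in for Metal's half4
    struct float4 { float v[4]; };

    // kernel as a template over its storage types, as in the Metal source
    template <typename q4_t, typename k4_t, typename v4_t,
              typename s_t, typename s4_t, typename o4_t>
    void kernel_sketch() {
        printf("softmax scratch: %zu bytes/elem, output: %zu bytes/elem\n",
               sizeof(s_t), sizeof(o4_t)/4);
    }

    // one precision mix, analogous to FA_TYPES: s_t/s4_t in float, rest half
    #define FA_TYPES_SKETCH \
        half4,              \
        half4,              \
        half4,              \
        float, float4,      \
        half4

    int main() {
        kernel_sketch<FA_TYPES_SKETCH>();
    }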