Spaces:
Sleeping
Sleeping
metal : fix floating-point range of attention scores in FA kernels (llama/13090)
Browse files
ggml/src/ggml-metal/ggml-metal.metal
CHANGED
|
@@ -3192,7 +3192,7 @@ kernel void kernel_flash_attn_ext(
|
|
| 3192 |
|
| 3193 |
{
|
| 3194 |
float S[Q] = { [0 ... Q-1] = 0.0f };
|
| 3195 |
-
float M[Q] = { [0 ... Q-1] = -
|
| 3196 |
|
| 3197 |
// thread indices inside the simdgroup
|
| 3198 |
// TODO: see if we can utilize quad-group functions for better performance
|
|
@@ -3452,7 +3452,7 @@ kernel void kernel_flash_attn_ext(
|
|
| 3452 |
// reduce the warps sequentially
|
| 3453 |
for (ushort sg = 1; sg < nsg; ++sg) {
|
| 3454 |
float S = { 0.0f };
|
| 3455 |
-
float M = { -
|
| 3456 |
|
| 3457 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 3458 |
|
|
@@ -3699,7 +3699,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3699 |
|
| 3700 |
{
|
| 3701 |
float S = 0.0f;
|
| 3702 |
-
float M = -
|
| 3703 |
|
| 3704 |
// thread indices inside the simdgroup
|
| 3705 |
const short tx = tiisg%NL;
|
|
|
|
| 3192 |
|
| 3193 |
{
|
| 3194 |
float S[Q] = { [0 ... Q-1] = 0.0f };
|
| 3195 |
+
float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 };
|
| 3196 |
|
| 3197 |
// thread indices inside the simdgroup
|
| 3198 |
// TODO: see if we can utilize quad-group functions for better performance
|
|
|
|
| 3452 |
// reduce the warps sequentially
|
| 3453 |
for (ushort sg = 1; sg < nsg; ++sg) {
|
| 3454 |
float S = { 0.0f };
|
| 3455 |
+
float M = { -__FLT_MAX__/2 };
|
| 3456 |
|
| 3457 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 3458 |
|
|
|
|
| 3699 |
|
| 3700 |
{
|
| 3701 |
float S = 0.0f;
|
| 3702 |
+
float M = -__FLT_MAX__/2;
|
| 3703 |
|
| 3704 |
// thread indices inside the simdgroup
|
| 3705 |
const short tx = tiisg%NL;
|