ggerganov commited on
Commit
e093044
·
1 Parent(s): ac537d2

metal : fix floating-point range of attention scores in FA kernels (llama/13090)

Browse files
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -3192,7 +3192,7 @@ kernel void kernel_flash_attn_ext(
3192
 
3193
  {
3194
  float S[Q] = { [0 ... Q-1] = 0.0f };
3195
- float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
3196
 
3197
  // thread indices inside the simdgroup
3198
  // TODO: see if we can utilize quad-group functions for better performance
@@ -3452,7 +3452,7 @@ kernel void kernel_flash_attn_ext(
3452
  // reduce the warps sequentially
3453
  for (ushort sg = 1; sg < nsg; ++sg) {
3454
  float S = { 0.0f };
3455
- float M = { -__FLT16_MAX__/2 };
3456
 
3457
  threadgroup_barrier(mem_flags::mem_threadgroup);
3458
 
@@ -3699,7 +3699,7 @@ kernel void kernel_flash_attn_ext_vec(
3699
 
3700
  {
3701
  float S = 0.0f;
3702
- float M = -__FLT16_MAX__/2;
3703
 
3704
  // thread indices inside the simdgroup
3705
  const short tx = tiisg%NL;
 
3192
 
3193
  {
3194
  float S[Q] = { [0 ... Q-1] = 0.0f };
3195
+ float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 };
3196
 
3197
  // thread indices inside the simdgroup
3198
  // TODO: see if we can utilize quad-group functions for better performance
 
3452
  // reduce the warps sequentially
3453
  for (ushort sg = 1; sg < nsg; ++sg) {
3454
  float S = { 0.0f };
3455
+ float M = { -__FLT_MAX__/2 };
3456
 
3457
  threadgroup_barrier(mem_flags::mem_threadgroup);
3458
 
 
3699
 
3700
  {
3701
  float S = 0.0f;
3702
+ float M = -__FLT_MAX__/2;
3703
 
3704
  // thread indices inside the simdgroup
3705
  const short tx = tiisg%NL;