JohannesGaessler committed
Commit ca79691 · 1 Parent(s): 2692ce5

CUDA: generalize FP16 fattn vec kernel (llama/7061)


* CUDA: generalize FP16 fattn vec kernel

* disable unsupported head sizes for AMD in test

* try AMD fix

* fix batch size 2-8

* partially revert changes
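
The gist of the change: instead of hard-coding one Q column per block, the FP16 flash-attention vector kernel is now templated on the number of columns it processes per block (ncols / cols_per_block in the diff below), with per-column accumulators kept in registers, and the host picks the instantiation from the batch size. Below is a minimal, illustrative sketch of that pattern rather than the actual kernel (dot_columns and launch_dot_columns are hypothetical names):

    // Sketch only: one block of 32 threads computes `ncols` dot products against the
    // same vector x, mirroring how the generalized fattn vec kernel keeps per-column state.
    #include <cuda_runtime.h>

    template <int ncols>
    __global__ void dot_columns(const float * x, const float * y, float * dst, const int d) {
        float acc[ncols] = {0.0f}; // per-column accumulators; compile-time size keeps them in registers

        for (int i = threadIdx.x; i < d; i += blockDim.x) {
            const float xi = x[i];
            #pragma unroll
            for (int j = 0; j < ncols; ++j) {
                acc[j] += xi*y[j*d + i];
            }
        }

        #pragma unroll
        for (int j = 0; j < ncols; ++j) {
            #pragma unroll
            for (int mask = 16; mask > 0; mask >>= 1) { // warp reduction, as in warp_reduce_sum
                acc[j] += __shfl_xor_sync(0xffffffff, acc[j], mask, 32);
            }
            if (threadIdx.x == 0) {
                dst[j] = acc[j];
            }
        }
    }

    // Host side: pick the template instantiation from the column count, in the same
    // spirit as the cols_per_block dispatch added to ggml_cuda_flash_attn_ext below.
    static void launch_dot_columns(const float * x, const float * y, float * dst,
                                   const int d, const int ncols, cudaStream_t stream) {
        switch (ncols) {
            case 1: dot_columns<1><<<1, 32, 0, stream>>>(x, y, dst, d); break;
            case 2: dot_columns<2><<<1, 32, 0, stream>>>(x, y, dst, d); break;
            case 4: dot_columns<4><<<1, 32, 0, stream>>>(x, y, dst, d); break;
            case 8: dot_columns<8><<<1, 32, 0, stream>>>(x, y, dst, d); break;
            default: break; // other column counts are not handled in this sketch
        }
    }

Because ncols is a template parameter, the per-column arrays have a compile-time size and the column loops can be fully unrolled, which is what lets a single kernel cover batch sizes 1 through 8 without a separate code path per size.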

Files changed (2)
  1. ggml-cuda/common.cuh +124 -108
  2. ggml-cuda/fattn.cu +223 -82
ggml-cuda/common.cuh CHANGED
@@ -234,6 +234,97 @@ typedef float dfloat; // dequantize float
234
  typedef float2 dfloat2;
235
  #endif //GGML_CUDA_F16
236
237
  [[noreturn]]
238
  static __device__ void no_device_code(
239
  const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
@@ -275,16 +366,28 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
275
  }
276
 
277
  static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
278
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 
 
279
  #pragma unroll
280
- for (int mask = 16; mask > 0; mask >>= 1) {
281
- a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
282
- }
283
- return a;
 
 
284
  #else
285
- GGML_UNUSED(a);
286
- NO_DEVICE_CODE;
287
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
288
  }
289
 
290
  static __device__ __forceinline__ float warp_reduce_max(float x) {
@@ -296,20 +399,21 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
296
  }
297
 
298
  static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
299
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
300
 
301
- #if CUDART_VERSION >= CUDART_HMAX
302
- return __hmax(a, b);
303
  #else
304
- return __half2float(a) > __half2float(b) ? a : b;
305
- #endif // CUDART_VERSION >= CUDART_HMAX
306
 
307
  #else
308
- GGML_UNUSED(a);
309
- GGML_UNUSED(b);
310
- NO_DEVICE_CODE;
311
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
312
  }
 
313
  static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
314
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
315
 
@@ -317,8 +421,8 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
317
  return __hmax2(a, b);
318
  #else
319
  half2 ret;
320
- reinterpret_cast<half&>(ret.x) = __low2float(a) > __low2float(b) ? __low2half(a) : __low2half(b);
321
- reinterpret_cast<half&>(ret.y) = __high2float(a) > __high2float(b) ? __high2half(a) : __high2half(b);
322
  return ret;
323
  #endif // CUDART_VERSION >= CUDART_HMAX
324
 
@@ -326,7 +430,7 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
326
  GGML_UNUSED(a);
327
  GGML_UNUSED(b);
328
  NO_DEVICE_CODE;
329
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
330
  }
331
 
332
  static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
@@ -350,94 +454,6 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
350
  }
351
  #endif // CUDART_VERSION < 12000
352
 
353
- #if defined(GGML_USE_HIPBLAS)
354
- #define __CUDA_ARCH__ 1300
355
-
356
- #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
357
- defined(__gfx1150__) || defined(__gfx1151__)
358
- #define RDNA3
359
- #endif
360
-
361
- #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
362
- defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
363
- #define RDNA2
364
- #endif
365
-
366
- #ifndef __has_builtin
367
- #define __has_builtin(x) 0
368
- #endif
369
-
370
- typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
371
- typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
372
- static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
373
- const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
374
- const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
375
- #if __has_builtin(__builtin_elementwise_sub_sat)
376
- const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
377
- return reinterpret_cast<const int &>(c);
378
- #else
379
- int8x4_t c;
380
- int16_t tmp;
381
- #pragma unroll
382
- for (int i = 0; i < 4; i++) {
383
- tmp = va[i] - vb[i];
384
- if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
385
- if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
386
- c[i] = tmp;
387
- }
388
- return reinterpret_cast<int &>(c);
389
- #endif // __has_builtin(__builtin_elementwise_sub_sat)
390
- }
391
-
392
- static __device__ __forceinline__ int __vsub4(const int a, const int b) {
393
- return __vsubss4(a, b);
394
- }
395
-
396
- static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
397
- const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
398
- const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
399
- unsigned int c;
400
- uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
401
- #pragma unroll
402
- for (int i = 0; i < 4; ++i) {
403
- vc[i] = va[i] == vb[i] ? 0xff : 0x00;
404
- }
405
- return c;
406
- }
407
-
408
- static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
409
- #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
410
- c = __builtin_amdgcn_sdot4(a, b, c, false);
411
- #elif defined(RDNA3)
412
- c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
413
- #elif defined(__gfx1010__) || defined(__gfx900__)
414
- int tmp1;
415
- int tmp2;
416
- asm("\n \
417
- v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
418
- v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
419
- v_add3_u32 %0, %1, %2, %0 \n \
420
- v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
421
- v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
422
- v_add3_u32 %0, %1, %2, %0 \n \
423
- "
424
- : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
425
- : "v"(a), "v"(b)
426
- );
427
- #else
428
- const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
429
- const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
430
- c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
431
- #endif
432
- return c;
433
- }
434
- #endif // defined(GGML_USE_HIPBLAS)
435
-
436
- #define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
437
- defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
438
-
439
- #define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
440
-
441
  // TODO: move to ggml-common.h
442
  static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
443
 
 
234
  typedef float2 dfloat2;
235
  #endif //GGML_CUDA_F16
236
 
237
+ #if defined(GGML_USE_HIPBLAS)
238
+ #define __CUDA_ARCH__ 1300
239
+
240
+ #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
241
+ defined(__gfx1150__) || defined(__gfx1151__)
242
+ #define RDNA3
243
+ #endif
244
+
245
+ #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
246
+ defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
247
+ #define RDNA2
248
+ #endif
249
+
250
+ #ifndef __has_builtin
251
+ #define __has_builtin(x) 0
252
+ #endif
253
+
254
+ typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
255
+ typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
256
+ static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
257
+ const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
258
+ const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
259
+ #if __has_builtin(__builtin_elementwise_sub_sat)
260
+ const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
261
+ return reinterpret_cast<const int &>(c);
262
+ #else
263
+ int8x4_t c;
264
+ int16_t tmp;
265
+ #pragma unroll
266
+ for (int i = 0; i < 4; i++) {
267
+ tmp = va[i] - vb[i];
268
+ if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
269
+ if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
270
+ c[i] = tmp;
271
+ }
272
+ return reinterpret_cast<int &>(c);
273
+ #endif // __has_builtin(__builtin_elementwise_sub_sat)
274
+ }
275
+
276
+ static __device__ __forceinline__ int __vsub4(const int a, const int b) {
277
+ return __vsubss4(a, b);
278
+ }
279
+
280
+ static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
281
+ const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
282
+ const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
283
+ unsigned int c;
284
+ uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
285
+ #pragma unroll
286
+ for (int i = 0; i < 4; ++i) {
287
+ vc[i] = va[i] == vb[i] ? 0xff : 0x00;
288
+ }
289
+ return c;
290
+ }
291
+
292
+ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
293
+ #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
294
+ c = __builtin_amdgcn_sdot4(a, b, c, false);
295
+ #elif defined(RDNA3)
296
+ c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
297
+ #elif defined(__gfx1010__) || defined(__gfx900__)
298
+ int tmp1;
299
+ int tmp2;
300
+ asm("\n \
301
+ v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
302
+ v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
303
+ v_add3_u32 %0, %1, %2, %0 \n \
304
+ v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
305
+ v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
306
+ v_add3_u32 %0, %1, %2, %0 \n \
307
+ "
308
+ : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
309
+ : "v"(a), "v"(b)
310
+ );
311
+ #else
312
+ const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
313
+ const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
314
+ c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
315
+ #endif
316
+ return c;
317
+ }
318
+ #endif // defined(GGML_USE_HIPBLAS)
319
+
320
+ #define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
321
+
322
+ #define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
323
+
324
+ static bool fp16_mma_available(const int cc) {
325
+ return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
326
+ }
327
+
328
  [[noreturn]]
329
  static __device__ void no_device_code(
330
  const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
 
366
  }
367
 
368
  static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
369
+ #if FP16_AVAILABLE
370
+
371
+ #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
372
  #pragma unroll
373
+ for (int mask = 16; mask > 0; mask >>= 1) {
374
+ const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
375
+ reinterpret_cast<half&>(a.x) += __low2half(a_other);
376
+ reinterpret_cast<half&>(a.y) += __high2half(a_other);
377
+ }
378
+ return a;
379
  #else
380
+ #pragma unroll
381
+ for (int mask = 16; mask > 0; mask >>= 1) {
382
+ a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
383
+ }
384
+ return a;
385
+ #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
386
+
387
+ #else
388
+ NO_DEVICE_CODE;
389
+ return a;
390
+ #endif // FP16_AVAILABLE
391
  }
392
 
393
  static __device__ __forceinline__ float warp_reduce_max(float x) {
 
399
  }
400
 
401
  static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
402
+ #if FP16_AVAILABLE
403
 
404
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
405
+ return __float2half(fmaxf(__half2float(a), __half2float(b)));
406
  #else
407
+ return __hmax(a, b);
408
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
409
 
410
  #else
411
+ NO_DEVICE_CODE;
412
+ GGML_UNUSED(b);
413
+ return a;
414
+ #endif // FP16_AVAILABLE
415
  }
416
+
417
  static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
418
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
419
 
 
421
  return __hmax2(a, b);
422
  #else
423
  half2 ret;
424
+ reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a), __low2float(b)));
425
+ reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
426
  return ret;
427
  #endif // CUDART_VERSION >= CUDART_HMAX
428
 
 
430
  GGML_UNUSED(a);
431
  GGML_UNUSED(b);
432
  NO_DEVICE_CODE;
433
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
434
  }
435
 
436
  static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 
454
  }
455
  #endif // CUDART_VERSION < 12000
456
457
  // TODO: move to ggml-common.h
458
  static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
459
 
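A note on the capability checks introduced above: FP16 support is now tested in two places. Inside device code the compile-time macros FP16_AVAILABLE and FP16_MMA_AVAILABLE guard the half-precision paths, while on the host fp16_mma_available(cc) decides at run time which kernel family to launch. A minimal sketch of that split, assuming the macros, NO_DEVICE_CODE and GGML_UNUSED from common.cuh are in scope (my_hadd and choose_fattn_path are hypothetical names, not part of the change):

    // Device side: same guard pattern as warp_reduce_sum(half2) and ggml_cuda_hmax above.
    static __device__ __forceinline__ half my_hadd(const half a, const half b) {
    #if FP16_AVAILABLE
        return __hadd(a, b);   // native FP16 math on Pascal+ NVIDIA and on AMD under HIP
    #else
        GGML_UNUSED(b);
        NO_DEVICE_CODE;        // compiled for an arch without usable FP16: trap if ever executed
        return a;
    #endif // FP16_AVAILABLE
    }

    // Host side: mirrors the fp16_mma_available(cc) check added to ggml_cuda_flash_attn_ext.
    static void choose_fattn_path(const int cc) {
        if (!fp16_mma_available(cc)) {
            // No tensor cores (AMD or pre-Volta NVIDIA): only the FP16 vec kernel applies,
            // and in this commit only for head sizes 64 and 128.
            // launch_fattn_vec_f16<...>(Q, K, V, KQV, mask, pool, stream);
            return;
        }
        // Tensor cores available: the mma-based flash_attn_ext_f16 kernels may be used as well.
    }

The compile-time macro controls what ends up in the generated device code; the run-time check keeps the host from launching a kernel variant that the present GPU cannot execute.
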
ggml-cuda/fattn.cu CHANGED
@@ -11,8 +11,10 @@
11
  #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
12
  #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
13
 
14
- template<int D, int parallel_blocks> // D == head size
15
- __launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
 
 
16
  static __global__ void flash_attn_vec_ext_f16(
17
  const char * __restrict__ Q,
18
  const char * __restrict__ K,
@@ -44,55 +46,77 @@ static __global__ void flash_attn_vec_ext_f16(
44
  #if FP16_AVAILABLE
45
  //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
46
 
47
- const int ic = blockIdx.x / parallel_blocks; // Index of the Q/QKV column to work on.
48
- const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
49
 
50
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
51
- const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic);
52
  const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
53
  const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
54
- const half * maskh = (const half *) mask + ne11*ic;
55
 
56
  const int stride_KV = nb11 / sizeof(half);
57
  const int stride_KV2 = nb11 / sizeof(half2);
58
 
59
- constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
 
60
  const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
61
- __builtin_assume(tid < nwarps*WARP_SIZE);
62
 
63
- __shared__ half KQ[nwarps*WARP_SIZE];
64
- KQ[tid] = -INFINITY;
 
 
 
65
  half2 * KQ2 = (half2 *) KQ;
66
 
67
- half kqmax = -HALF_MAX_HALF;
68
- half kqsum = 0.0f;
 
 
 
 
69
 
70
- __shared__ half kqmax_shared[WARP_SIZE];
71
- __shared__ half kqsum_shared[WARP_SIZE];
72
- if (threadIdx.y == 0) {
73
- kqmax_shared[threadIdx.x] = -HALF_MAX_HALF;
74
- kqsum_shared[threadIdx.x] = 0.0f;
 
 
 
75
  }
76
  __syncthreads();
77
 
78
  // Convert Q to half2 and store in registers:
79
- half2 Q_h2[(D/2 + WARP_SIZE - 1) / WARP_SIZE];
80
  #pragma unroll
81
- for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
82
- const int i = i0 + threadIdx.x;
83
- if (i0 + WARP_SIZE > D/2 && i >= D/2) {
84
- break;
85
- }
86
 
87
- Q_h2[i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(Q_f2[i].x, Q_f2[i].y);
 
 
88
  }
89
 
90
- half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value.
91
 
92
- const int k_start = parallel_blocks == 1 ? 0 : ip*D;
93
  for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
94
  // Calculate KQ tile and keep track of new maximum KQ values:
95
- half kqmax_new = kqmax;
96
  #pragma unroll
97
  for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
98
  const int i_KQ = i_KQ_0 + threadIdx.y;
@@ -101,89 +125,112 @@ static __global__ void flash_attn_vec_ext_f16(
101
  break;
102
  }
103
 
104
- half2 sum2 = make_half2(0.0f, 0.0f);
105
  #pragma unroll
106
  for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
107
  const int k_KQ = k_KQ_0 + threadIdx.x;
108
- if (k_KQ_0 + WARP_SIZE > D/2 && k_KQ >= D/2) {
109
- break;
110
- }
111
 
112
  const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
113
- sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
 
 
 
114
  }
115
 
116
- sum2 = warp_reduce_sum(sum2);
117
- half sum = __low2half(sum2) + __high2half(sum2);
118
- sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f);
119
- kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
120
- if (threadIdx.x == 0) {
121
- KQ[i_KQ] = sum;
122
  }
123
  }
124
 
125
- kqmax_new = warp_reduce_max(kqmax_new);
126
- if (threadIdx.x == 0) {
127
- kqmax_shared[threadIdx.y] = kqmax_new;
128
  }
 
129
  __syncthreads();
130
- kqmax_new = kqmax_shared[threadIdx.x];
131
- kqmax_new = warp_reduce_max(kqmax_new);
132
 
133
- const half KQ_max_scale = hexp(kqmax - kqmax_new);
134
- kqmax = kqmax_new;
135
 
136
- const half val = hexp(KQ[tid] - kqmax);
137
- kqsum = kqsum*KQ_max_scale + val;
138
- KQ[tid] = val;
139
 
140
- VKQ *= __half2half2(KQ_max_scale);
 
141
 
142
  __syncthreads();
143
 
144
- if (tid < D) {
145
  #pragma unroll
146
- for (int k0 = 0; k0 < D; k0 += 2) {
147
- if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
148
- break;
149
- }
150
 
151
- half2 V_k;
152
- reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
153
- reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
154
- VKQ += V_k*KQ2[k0/2];
 
 
155
  }
156
  }
157
 
158
  __syncthreads();
159
  }
160
 
161
- if (tid >= D) {
162
- kqsum = 0.0f;
 
 
 
 
163
  }
164
 
165
- kqsum = warp_reduce_sum(kqsum);
166
- if (threadIdx.x == 0) {
167
- kqsum_shared[threadIdx.y] = kqsum;
168
- }
169
  __syncthreads();
170
- kqsum = kqsum_shared[threadIdx.x];
171
- kqsum = warp_reduce_sum(kqsum);
172
 
173
- if (tid >= D) {
174
- return;
175
- }
 
176
 
177
- half dst_val = (__low2half(VKQ) + __high2half(VKQ));
178
- if (parallel_blocks == 1) {
179
- dst_val /= kqsum;
 
 
 
180
  }
181
- dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val;
182
 
183
- if (parallel_blocks == 1 || tid != 0) {
184
- return;
 
 
 
185
  }
186
- dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum);
187
  #else
188
  NO_DEVICE_CODE;
189
  #endif // FP16_AVAILABLE
@@ -191,7 +238,9 @@ static __global__ void flash_attn_vec_ext_f16(
191
 
192
  // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
193
  template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
 
194
  __launch_bounds__(nwarps*WARP_SIZE, 1)
 
195
  static __global__ void flash_attn_ext_f16(
196
  const char * __restrict__ Q,
197
  const char * __restrict__ K,
@@ -573,7 +622,9 @@ static __global__ void flash_attn_ext_f16(
573
  }
574
 
575
  template<int D, int parallel_blocks> // D == head size
 
576
  __launch_bounds__(D, 1)
 
577
  static __global__ void flash_attn_combine_results(
578
  const float * __restrict__ VKQ_parts,
579
  const float2 * __restrict__ VKQ_meta,
@@ -642,7 +693,7 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
642
  static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
643
  static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
644
 
645
- template <int D, int parallel_blocks> void launch_fattn_vec_f16(
646
  const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
647
  ggml_cuda_pool & pool, cudaStream_t main_stream
648
  ) {
@@ -656,13 +707,13 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
656
 
657
  constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
658
  const dim3 block_dim(WARP_SIZE, nwarps, 1);
659
- const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
660
  const int shmem = 0;
661
 
662
  float scale;
663
  memcpy(&scale, KQV->op_params, sizeof(float));
664
 
665
- flash_attn_vec_ext_f16<D, parallel_blocks>
666
  <<<blocks_num, block_dim, shmem, main_stream>>> (
667
  (const char *) Q->data,
668
  (const char *) K->data,
@@ -783,10 +834,99 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
783
 
784
  ggml_cuda_set_device(ctx.device);
785
 
 
786
  const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
787
 
788
  const int32_t precision = KQV->op_params[1];
789
790
  if (precision != GGML_PREC_DEFAULT) {
791
  if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
792
  constexpr int cols_per_block = 16;
@@ -845,16 +985,17 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
845
  }
846
 
847
  if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
 
848
  constexpr int parallel_blocks = 4;
849
  switch (Q->ne[0]) {
850
  case 64:
851
- launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
852
  break;
853
  case 128:
854
- launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
855
  break;
856
  case 256:
857
- launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
858
  break;
859
  default:
860
  GGML_ASSERT(false);
 
11
  #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
12
  #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
13
 
14
+ template<int D, int ncols, int parallel_blocks> // D == head size
15
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
16
+ __launch_bounds__(D, 1)
17
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
18
  static __global__ void flash_attn_vec_ext_f16(
19
  const char * __restrict__ Q,
20
  const char * __restrict__ K,
 
46
  #if FP16_AVAILABLE
47
  //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
48
 
49
+ const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
50
+ const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
51
 
52
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
53
+ const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
54
  const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
55
  const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
56
+ const half * maskh = (const half *) mask + ne11*ic0;
57
 
58
  const int stride_KV = nb11 / sizeof(half);
59
  const int stride_KV2 = nb11 / sizeof(half2);
60
 
61
+ static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
62
+ constexpr int nwarps = D / WARP_SIZE;
63
  const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
64
+ __builtin_assume(tid < D);
65
 
66
+ __shared__ half KQ[ncols*D];
67
+ #pragma unroll
68
+ for (int j = 0; j < ncols; ++j) {
69
+ KQ[j*D + tid] = -HALF_MAX_HALF;
70
+ }
71
  half2 * KQ2 = (half2 *) KQ;
72
 
73
+ half kqmax[ncols];
74
+ #pragma unroll
75
+ for (int j = 0; j < ncols; ++j) {
76
+ kqmax[j] = -HALF_MAX_HALF;
77
+ }
78
+ half kqsum[ncols] = {0.0f};
79
 
80
+ __shared__ half kqmax_shared[ncols][WARP_SIZE];
81
+ __shared__ half kqsum_shared[ncols][WARP_SIZE];
82
+ #pragma unroll
83
+ for (int j = 0; j < ncols; ++j) {
84
+ if (threadIdx.y == 0) {
85
+ kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF;
86
+ kqsum_shared[j][threadIdx.x] = 0.0f;
87
+ }
88
  }
89
  __syncthreads();
90
 
91
  // Convert Q to half2 and store in registers:
92
+ half2 Q_h2[ncols][D/(2*WARP_SIZE)];
93
  #pragma unroll
94
+ for (int j = 0; j < ncols; ++j) {
95
+ #pragma unroll
96
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
97
+ const int i = i0 + threadIdx.x;
 
98
 
99
+ const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
100
+ Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
101
+ }
102
  }
103
 
104
+ half2 VKQ[ncols] = {{0.0f, 0.0f}};
105
 
106
+ const int k_start = parallel_blocks == 1 ? 0 : ip*D;
107
  for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
108
  // Calculate KQ tile and keep track of new maximum KQ values:
109
+
110
+ // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
111
+ // see https://github.com/ggerganov/llama.cpp/pull/7061 .
112
+ // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
113
+ half kqmax_new = kqmax[0];
114
+ half kqmax_new_arr[ncols];
115
+ #pragma unroll
116
+ for (int j = 0; j < ncols; ++j) {
117
+ kqmax_new_arr[j] = kqmax[j];
118
+ }
119
+
120
  #pragma unroll
121
  for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
122
  const int i_KQ = i_KQ_0 + threadIdx.y;
 
125
  break;
126
  }
127
 
128
+ half2 sum2[ncols] = {{0.0f, 0.0f}};
129
  #pragma unroll
130
  for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
131
  const int k_KQ = k_KQ_0 + threadIdx.x;
 
 
 
132
 
133
  const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
134
+ #pragma unroll
135
+ for (int j = 0; j < ncols; ++j) {
136
+ sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE];
137
+ }
138
  }
139
 
140
+ #pragma unroll
141
+ for (int j = 0; j < ncols; ++j) {
142
+ sum2[j] = warp_reduce_sum(sum2[j]);
143
+ half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
144
+ sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
145
+
146
+ if (ncols == 1) {
147
+ kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
148
+ } else {
149
+ kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
150
+ }
151
+
152
+ if (threadIdx.x == 0) {
153
+ KQ[j*D + i_KQ] = sum;
154
+ }
155
  }
156
  }
157
 
158
+ #pragma unroll
159
+ for (int j = 0; j < ncols; ++j) {
160
+ half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
161
+
162
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
163
+ if (threadIdx.x == 0) {
164
+ kqmax_shared[j][threadIdx.y] = kqmax_new_j;
165
+ }
166
  }
167
+
168
  __syncthreads();
 
 
169
 
170
+ #pragma unroll
171
+ for (int j = 0; j < ncols; ++j) {
172
+ half kqmax_new_j = kqmax_shared[j][threadIdx.x];
173
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
174
+
175
+ const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
176
+ kqmax[j] = kqmax_new_j;
177
 
178
+ const half val = hexp(KQ[j*D + tid] - kqmax[j]);
179
+ kqsum[j] = kqsum[j]*KQ_max_scale + val;
180
+ KQ[j*D + tid] = val;
181
 
182
+ VKQ[j] *= __half2half2(KQ_max_scale);
183
+ }
184
 
185
  __syncthreads();
186
 
 
187
  #pragma unroll
188
+ for (int k0 = 0; k0 < D; k0 += 2) {
189
+ if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
190
+ break;
191
+ }
192
 
193
+ half2 V_k;
194
+ reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
195
+ reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
196
+ #pragma unroll
197
+ for (int j = 0; j < ncols; ++j) {
198
+ VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
199
  }
200
  }
201
 
202
  __syncthreads();
203
  }
204
 
205
+ #pragma unroll
206
+ for (int j = 0; j < ncols; ++j) {
207
+ kqsum[j] = warp_reduce_sum(kqsum[j]);
208
+ if (threadIdx.x == 0) {
209
+ kqsum_shared[j][threadIdx.y] = kqsum[j];
210
+ }
211
  }
212
 
 
 
 
 
213
  __syncthreads();
 
 
214
 
215
+ #pragma unroll
216
+ for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
217
+ kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
218
+ kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
219
 
220
+ half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
221
+ if (parallel_blocks == 1) {
222
+ dst_val /= kqsum[j_VKQ];
223
+ }
224
+ const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
225
+ dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
226
  }
 
227
 
228
+ if (parallel_blocks != 1 && tid != 0) {
229
+ #pragma unroll
230
+ for (int j = 0; j < ncols; ++j) {
231
+ dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
232
+ }
233
  }
 
234
  #else
235
  NO_DEVICE_CODE;
236
  #endif // FP16_AVAILABLE
 
238
 
239
  // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
240
  template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
241
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
242
  __launch_bounds__(nwarps*WARP_SIZE, 1)
243
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
244
  static __global__ void flash_attn_ext_f16(
245
  const char * __restrict__ Q,
246
  const char * __restrict__ K,
 
622
  }
623
 
624
  template<int D, int parallel_blocks> // D == head size
625
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
626
  __launch_bounds__(D, 1)
627
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
628
  static __global__ void flash_attn_combine_results(
629
  const float * __restrict__ VKQ_parts,
630
  const float2 * __restrict__ VKQ_meta,
 
693
  static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
694
  static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
695
 
696
+ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f16(
697
  const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
698
  ggml_cuda_pool & pool, cudaStream_t main_stream
699
  ) {
 
707
 
708
  constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
709
  const dim3 block_dim(WARP_SIZE, nwarps, 1);
710
+ const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
711
  const int shmem = 0;
712
 
713
  float scale;
714
  memcpy(&scale, KQV->op_params, sizeof(float));
715
 
716
+ flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
717
  <<<blocks_num, block_dim, shmem, main_stream>>> (
718
  (const char *) Q->data,
719
  (const char *) K->data,
 
834
 
835
  ggml_cuda_set_device(ctx.device);
836
 
837
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
838
  const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
839
 
840
  const int32_t precision = KQV->op_params[1];
841
 
842
+ if (!fp16_mma_available(cc)) {
843
+ GGML_ASSERT(precision == GGML_PREC_DEFAULT);
844
+ GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
845
+
846
+ if (Q->ne[1] == 1) {
847
+ constexpr int cols_per_block = 1;
848
+ constexpr int parallel_blocks = 4;
849
+ switch (Q->ne[0]) {
850
+ case 64:
851
+ launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
852
+ break;
853
+ case 128:
854
+ launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
855
+ break;
856
+ default:
857
+ GGML_ASSERT(false);
858
+ break;
859
+ }
860
+ return;
861
+ }
862
+
863
+ if (Q->ne[1] == 2) {
864
+ constexpr int cols_per_block = 2;
865
+ constexpr int parallel_blocks = 4;
866
+ switch (Q->ne[0]) {
867
+ case 64:
868
+ launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
869
+ break;
870
+ case 128:
871
+ launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
872
+ break;
873
+ default:
874
+ GGML_ASSERT(false);
875
+ break;
876
+ }
877
+ return;
878
+ }
879
+
880
+ if (Q->ne[1] <= 4) {
881
+ constexpr int cols_per_block = 4;
882
+ constexpr int parallel_blocks = 4;
883
+ switch (Q->ne[0]) {
884
+ case 64:
885
+ launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
886
+ break;
887
+ case 128:
888
+ launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
889
+ break;
890
+ default:
891
+ GGML_ASSERT(false);
892
+ break;
893
+ }
894
+ return;
895
+ }
896
+
897
+ if (Q->ne[1] <= 8) {
898
+ constexpr int cols_per_block = 8;
899
+ constexpr int parallel_blocks = 4;
900
+ switch (Q->ne[0]) {
901
+ case 64:
902
+ launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
903
+ break;
904
+ case 128:
905
+ launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
906
+ break;
907
+ default:
908
+ GGML_ASSERT(false);
909
+ break;
910
+ }
911
+ return;
912
+ }
913
+
914
+ constexpr int cols_per_block = 8;
915
+ constexpr int parallel_blocks = 1;
916
+ switch (Q->ne[0]) {
917
+ case 64:
918
+ launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
919
+ break;
920
+ case 128:
921
+ launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
922
+ break;
923
+ default:
924
+ GGML_ASSERT(false);
925
+ break;
926
+ }
927
+ return;
928
+ }
929
+
930
  if (precision != GGML_PREC_DEFAULT) {
931
  if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
932
  constexpr int cols_per_block = 16;
 
985
  }
986
 
987
  if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
988
+ constexpr int cols_per_block = 1;
989
  constexpr int parallel_blocks = 4;
990
  switch (Q->ne[0]) {
991
  case 64:
992
+ launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
993
  break;
994
  case 128:
995
+ launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
996
  break;
997
  case 256:
998
+ launch_fattn_vec_f16<256, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
999
  break;
1000
  default:
1001
  GGML_ASSERT(false);