JohannesGaessler committed
Commit 03d4b22 · unverified · 1 Parent(s): 7b58c58

CUDA: add FP32 FlashAttention vector kernel (llama/7188)


* CUDA: add FP32 FlashAttention vector kernel

* fixup! CUDA: add FP32 FlashAttention vector kernel

* fixup! fixup! CUDA: add FP32 FlashAttention vector kernel

* fixup! fixup! fixup! CUDA: add FP32 FlashAttention vector kernel

ggml-cuda.cu CHANGED
@@ -2713,6 +2713,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
     }
 
 GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
@@ -2840,8 +2841,16 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
-        case GGML_OP_FLASH_ATTN_EXT:
             return true;
+        case GGML_OP_FLASH_ATTN_EXT:
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+            return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128;
+#else
+            if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
+                return true;
+            }
+            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
             return false;
     }
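
With this change FLASH_ATTN_EXT is no longer reported as unconditionally supported. A standalone sketch of the new gating logic (illustrative only, not part of the commit; CC_VOLTA == 700 and the HIP/NVIDIA split are taken from the hunk above, the helper name is invented here):

// Head sizes 64/128 are always handled by the vector kernels; anything else
// needs the tensor-core (MMA) kernel, i.e. Volta or newer on NVIDIA.
static bool flash_attn_ext_supported(int head_size, int compute_capability, bool is_hip_amd) {
    if (is_hip_amd) {
        return head_size == 64 || head_size == 128; // only the vector kernels exist on AMD
    }
    if (head_size == 64 || head_size == 128) {
        return true;                                // vector kernels run on any NVIDIA GPU
    }
    return compute_capability >= 700;               // MMA kernel required
}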
ggml-cuda/common.cuh CHANGED
@@ -321,6 +321,10 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 
 #define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
 
+static bool fast_fp16_available(const int cc) {
+    return cc >= CC_PASCAL && cc != 610;
+}
+
 static bool fp16_mma_available(const int cc) {
     return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
 }
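
The new helper decides whether FP16 arithmetic is worth using at all. A minimal sketch of its behaviour, assuming CC_PASCAL == 600 and ggml's cc = 100*major + 10*minor convention (both assumptions from common ggml-cuda definitions, not from this diff):

// Accepts P100-class Pascal (cc 600) and everything from Volta (700) up; rejects
// pre-Pascal GPUs and cc 610 consumer Pascal, whose FP16 throughput is far below FP32.
static bool fast_fp16_available_sketch(const int cc) {
    return cc >= 600 && cc != 610;
}

On devices where this returns false, the dispatcher in fattn.cu (last hunk of this commit) routes FlashAttention to the new FP32 vector kernel instead of the FP16 one.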
ggml-cuda/fattn-common.cuh ADDED
@@ -0,0 +1,47 @@
+#define FATTN_KQ_STRIDE 256
+#define HALF_MAX_HALF         __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
+#define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
+
+template<int D, int parallel_blocks> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_combine_results(
+        const float  * __restrict__ VKQ_parts,
+        const float2 * __restrict__ VKQ_meta,
+        float * __restrict__ dst) {
+    VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
+    VKQ_meta  += parallel_blocks   * gridDim.y*blockIdx.x;
+    dst       +=                 D * gridDim.y*blockIdx.x;
+
+    const int tid = threadIdx.x;
+    __builtin_assume(tid < D);
+
+    __shared__ float2 meta[parallel_blocks];
+    if (tid < 2*parallel_blocks) {
+        ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
+    }
+
+    __syncthreads();
+
+    float kqmax = meta[0].x;
+#pragma unroll
+    for (int l = 1; l < parallel_blocks; ++l) {
+        kqmax = max(kqmax, meta[l].x);
+    }
+
+    float VKQ_numerator   = 0.0f;
+    float VKQ_denominator = 0.0f;
+#pragma unroll
+    for (int l = 0; l < parallel_blocks; ++l) {
+        const float diff = meta[l].x - kqmax;
+        const float KQ_max_scale = expf(diff);
+        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
+        *((uint32_t *) &KQ_max_scale) &= ftz_mask;
+
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
+        VKQ_denominator += KQ_max_scale * meta[l].y;
+    }
+
+    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
+}
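
The combine kernel above merges the partial results produced when one Q column is split across parallel_blocks thread blocks. A host-side reference sketch of the same reduction (illustrative only; the (m, s, N) naming is introduced here and maps to VKQ_meta.x, VKQ_meta.y and VKQ_parts in the kernel):

#include <math.h>

// Merge partial softmax results: chunk l contributes its running KQ max m[l],
// its softmax denominator s[l], and its unnormalized V*softmax(KQ) value N[l].
float combine_partial_attention(const float * m, const float * s, const float * N, int parallel_blocks) {
    float m_max = m[0];
    for (int l = 1; l < parallel_blocks; ++l) {
        m_max = fmaxf(m_max, m[l]);
    }
    float num = 0.0f;
    float den = 0.0f;
    for (int l = 0; l < parallel_blocks; ++l) {
        const float scale = expf(m[l] - m_max); // rescale every chunk to the common maximum
        num += scale * N[l];
        den += scale * s[l];
    }
    return num / den; // equals the attention output computed over the full KV sequence
}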
ggml-cuda/fattn-vec-f16.cu ADDED
@@ -0,0 +1,430 @@
1
+ #include "common.cuh"
2
+ #include "fattn-common.cuh"
3
+ #include "fattn-vec-f16.cuh"
4
+
5
+ template<int D, int ncols, int parallel_blocks> // D == head size
6
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
7
+ __launch_bounds__(D, 1)
8
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
9
+ static __global__ void flash_attn_vec_ext_f16(
10
+ const char * __restrict__ Q,
11
+ const char * __restrict__ K,
12
+ const char * __restrict__ V,
13
+ const char * __restrict__ mask,
14
+ float * __restrict__ dst,
15
+ float2 * __restrict__ dst_meta,
16
+ const float scale,
17
+ const float max_bias,
18
+ const float m0,
19
+ const float m1,
20
+ const uint32_t n_head_log2,
21
+ const int ne00,
22
+ const int ne01,
23
+ const int ne02,
24
+ const int ne03,
25
+ const int ne10,
26
+ const int ne11,
27
+ const int ne12,
28
+ const int ne13,
29
+ const int ne31,
30
+ const int nb31,
31
+ const int nb01,
32
+ const int nb02,
33
+ const int nb03,
34
+ const int nb11,
35
+ const int nb12,
36
+ const int nb13,
37
+ const int ne0,
38
+ const int ne1,
39
+ const int ne2,
40
+ const int ne3) {
41
+ #if FP16_AVAILABLE
42
+ //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
43
+
44
+ const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
45
+ const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
46
+
47
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
48
+ const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
49
+ const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
50
+ const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
51
+ const half * maskh = (const half *) mask + ne11*ic0;
52
+
53
+ const int stride_KV = nb11 / sizeof(half);
54
+ const int stride_KV2 = nb11 / sizeof(half2);
55
+
56
+ half slopeh = __float2half(1.0f);
57
+
58
+ // ALiBi
59
+ if (max_bias > 0.0f) {
60
+ const int h = blockIdx.y;
61
+
62
+ const float base = h < n_head_log2 ? m0 : m1;
63
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
64
+
65
+ slopeh = __float2half(powf(base, exph));
66
+ }
67
+
68
+ static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
69
+ constexpr int nwarps = D / WARP_SIZE;
70
+ const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
71
+ __builtin_assume(tid < D);
72
+
73
+ __shared__ half KQ[ncols*D];
74
+ #pragma unroll
75
+ for (int j = 0; j < ncols; ++j) {
76
+ KQ[j*D + tid] = -HALF_MAX_HALF;
77
+ }
78
+ half2 * KQ2 = (half2 *) KQ;
79
+
80
+ half kqmax[ncols];
81
+ #pragma unroll
82
+ for (int j = 0; j < ncols; ++j) {
83
+ kqmax[j] = -HALF_MAX_HALF;
84
+ }
85
+ half kqsum[ncols] = {0.0f};
86
+
87
+ __shared__ half kqmax_shared[ncols][WARP_SIZE];
88
+ __shared__ half kqsum_shared[ncols][WARP_SIZE];
89
+ #pragma unroll
90
+ for (int j = 0; j < ncols; ++j) {
91
+ if (threadIdx.y == 0) {
92
+ kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF;
93
+ kqsum_shared[j][threadIdx.x] = 0.0f;
94
+ }
95
+ }
96
+ __syncthreads();
97
+
98
+ // Convert Q to half2 and store in registers:
99
+ half2 Q_h2[ncols][D/(2*WARP_SIZE)];
100
+ #pragma unroll
101
+ for (int j = 0; j < ncols; ++j) {
102
+ #pragma unroll
103
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
104
+ const int i = i0 + threadIdx.x;
105
+
106
+ const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
107
+ Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
108
+ }
109
+ }
110
+
111
+ half2 VKQ[ncols] = {{0.0f, 0.0f}};
112
+
113
+ const int k_start = parallel_blocks == 1 ? 0 : ip*D;
114
+ for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
115
+ // Calculate KQ tile and keep track of new maximum KQ values:
116
+
117
+ // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
118
+ // see https://github.com/ggerganov/llama.cpp/pull/7061 .
119
+ // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
120
+ half kqmax_new = kqmax[0];
121
+ half kqmax_new_arr[ncols];
122
+ #pragma unroll
123
+ for (int j = 0; j < ncols; ++j) {
124
+ kqmax_new_arr[j] = kqmax[j];
125
+ }
126
+
127
+ #pragma unroll
128
+ for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
129
+ const int i_KQ = i_KQ_0 + threadIdx.y;
130
+
131
+ if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
132
+ break;
133
+ }
134
+
135
+ half2 sum2[ncols] = {{0.0f, 0.0f}};
136
+ #pragma unroll
137
+ for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
138
+ const int k_KQ = k_KQ_0 + threadIdx.x;
139
+
140
+ const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
141
+ #pragma unroll
142
+ for (int j = 0; j < ncols; ++j) {
143
+ sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE];
144
+ }
145
+ }
146
+
147
+ #pragma unroll
148
+ for (int j = 0; j < ncols; ++j) {
149
+ sum2[j] = warp_reduce_sum(sum2[j]);
150
+ half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
151
+ sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
152
+
153
+ if (ncols == 1) {
154
+ kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
155
+ } else {
156
+ kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
157
+ }
158
+
159
+ if (threadIdx.x == 0) {
160
+ KQ[j*D + i_KQ] = sum;
161
+ }
162
+ }
163
+ }
164
+
165
+ #pragma unroll
166
+ for (int j = 0; j < ncols; ++j) {
167
+ half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
168
+
169
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
170
+ if (threadIdx.x == 0) {
171
+ kqmax_shared[j][threadIdx.y] = kqmax_new_j;
172
+ }
173
+ }
174
+
175
+ __syncthreads();
176
+
177
+ #pragma unroll
178
+ for (int j = 0; j < ncols; ++j) {
179
+ half kqmax_new_j = kqmax_shared[j][threadIdx.x];
180
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
181
+
182
+ const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
183
+ kqmax[j] = kqmax_new_j;
184
+
185
+ const half val = hexp(KQ[j*D + tid] - kqmax[j]);
186
+ kqsum[j] = kqsum[j]*KQ_max_scale + val;
187
+ KQ[j*D + tid] = val;
188
+
189
+ VKQ[j] *= __half2half2(KQ_max_scale);
190
+ }
191
+
192
+ __syncthreads();
193
+
194
+ #pragma unroll
195
+ for (int k0 = 0; k0 < D; k0 += 2) {
196
+ if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
197
+ break;
198
+ }
199
+
200
+ half2 V_k;
201
+ reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
202
+ reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
203
+ #pragma unroll
204
+ for (int j = 0; j < ncols; ++j) {
205
+ VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
206
+ }
207
+ }
208
+
209
+ __syncthreads();
210
+ }
211
+
212
+ #pragma unroll
213
+ for (int j = 0; j < ncols; ++j) {
214
+ kqsum[j] = warp_reduce_sum(kqsum[j]);
215
+ if (threadIdx.x == 0) {
216
+ kqsum_shared[j][threadIdx.y] = kqsum[j];
217
+ }
218
+ }
219
+
220
+ __syncthreads();
221
+
222
+ #pragma unroll
223
+ for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
224
+ kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
225
+ kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
226
+
227
+ half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
228
+ if (parallel_blocks == 1) {
229
+ dst_val /= kqsum[j_VKQ];
230
+ }
231
+ const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
232
+ dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
233
+ }
234
+
235
+ if (parallel_blocks != 1 && tid != 0) {
236
+ #pragma unroll
237
+ for (int j = 0; j < ncols; ++j) {
238
+ dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
239
+ }
240
+ }
241
+ #else
242
+ NO_DEVICE_CODE;
243
+ #endif // FP16_AVAILABLE
244
+ }
245
+
246
+ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f16(
247
+ const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
248
+ ggml_cuda_pool & pool, cudaStream_t main_stream
249
+ ) {
250
+ ggml_cuda_pool_alloc<float> dst_tmp(pool);
251
+ ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
252
+
253
+ if (parallel_blocks > 1) {
254
+ dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
255
+ dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
256
+ }
257
+
258
+ constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
259
+ const dim3 block_dim(WARP_SIZE, nwarps, 1);
260
+ const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
261
+ const int shmem = 0;
262
+
263
+ float scale = 1.0f;
264
+ float max_bias = 0.0f;
265
+
266
+ memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
267
+ memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
268
+
269
+ const uint32_t n_head = Q->ne[2];
270
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
271
+
272
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
273
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
274
+
275
+ flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
276
+ <<<blocks_num, block_dim, shmem, main_stream>>> (
277
+ (const char *) Q->data,
278
+ (const char *) K->data,
279
+ (const char *) V->data,
280
+ mask ? ((const char *) mask->data) : nullptr,
281
+ parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
282
+ scale, max_bias, m0, m1, n_head_log2,
283
+ Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
284
+ K->ne[0], K->ne[1], K->ne[2], K->ne[3],
285
+ mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
286
+ Q->nb[1], Q->nb[2], Q->nb[3],
287
+ K->nb[1], K->nb[2], K->nb[3],
288
+ KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
289
+ );
290
+ CUDA_CHECK(cudaGetLastError());
291
+
292
+ if (parallel_blocks == 1) {
293
+ return;
294
+ }
295
+
296
+ const dim3 block_dim_combine(D, 1, 1);
297
+ const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
298
+ const int shmem_combine = 0;
299
+
300
+ flash_attn_combine_results<D, parallel_blocks>
301
+ <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
302
+ (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
303
+ CUDA_CHECK(cudaGetLastError());
304
+ }
305
+
306
+ void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
307
+ const ggml_tensor * Q = dst->src[0];
308
+ const ggml_tensor * K = dst->src[1];
309
+ const ggml_tensor * V = dst->src[2];
310
+
311
+ const ggml_tensor * mask = dst->src[3];
312
+
313
+ ggml_tensor * KQV = dst;
314
+
315
+ const int32_t precision = KQV->op_params[2];
316
+ GGML_ASSERT(precision == GGML_PREC_DEFAULT);
317
+
318
+ constexpr int cols_per_block = 1;
319
+ constexpr int parallel_blocks = 4;
320
+ switch (Q->ne[0]) {
321
+ case 64:
322
+ launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
323
+ break;
324
+ case 128:
325
+ launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
326
+ break;
327
+ case 256:
328
+ launch_fattn_vec_f16<256, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
329
+ break;
330
+ default:
331
+ GGML_ASSERT(false);
332
+ break;
333
+ }
334
+ }
335
+
336
+ void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
337
+ const ggml_tensor * Q = dst->src[0];
338
+ const ggml_tensor * K = dst->src[1];
339
+ const ggml_tensor * V = dst->src[2];
340
+
341
+ const ggml_tensor * mask = dst->src[3];
342
+
343
+ ggml_tensor * KQV = dst;
344
+
345
+ const int32_t precision = KQV->op_params[2];
346
+ GGML_ASSERT(precision == GGML_PREC_DEFAULT);
347
+ GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
348
+
349
+ if (Q->ne[1] == 1) {
350
+ constexpr int cols_per_block = 1;
351
+ constexpr int parallel_blocks = 4;
352
+ switch (Q->ne[0]) {
353
+ case 64:
354
+ launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
355
+ break;
356
+ case 128:
357
+ launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
358
+ break;
359
+ default:
360
+ GGML_ASSERT(false);
361
+ break;
362
+ }
363
+ return;
364
+ }
365
+
366
+ if (Q->ne[1] == 2) {
367
+ constexpr int cols_per_block = 2;
368
+ constexpr int parallel_blocks = 4;
369
+ switch (Q->ne[0]) {
370
+ case 64:
371
+ launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
372
+ break;
373
+ case 128:
374
+ launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
375
+ break;
376
+ default:
377
+ GGML_ASSERT(false);
378
+ break;
379
+ }
380
+ return;
381
+ }
382
+
383
+ if (Q->ne[1] <= 4) {
384
+ constexpr int cols_per_block = 4;
385
+ constexpr int parallel_blocks = 4;
386
+ switch (Q->ne[0]) {
387
+ case 64:
388
+ launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
389
+ break;
390
+ case 128:
391
+ launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
392
+ break;
393
+ default:
394
+ GGML_ASSERT(false);
395
+ break;
396
+ }
397
+ return;
398
+ }
399
+
400
+ if (Q->ne[1] <= 8) {
401
+ constexpr int cols_per_block = 8;
402
+ constexpr int parallel_blocks = 4;
403
+ switch (Q->ne[0]) {
404
+ case 64:
405
+ launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
406
+ break;
407
+ case 128:
408
+ launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
409
+ break;
410
+ default:
411
+ GGML_ASSERT(false);
412
+ break;
413
+ }
414
+ return;
415
+ }
416
+
417
+ constexpr int cols_per_block = 8;
418
+ constexpr int parallel_blocks = 1;
419
+ switch (Q->ne[0]) {
420
+ case 64:
421
+ launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
422
+ break;
423
+ case 128:
424
+ launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
425
+ break;
426
+ default:
427
+ GGML_ASSERT(false);
428
+ break;
429
+ }
430
+ }
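
Both vector kernels in this commit use the same online-softmax update: KQ scores arrive tile by tile, and the running maximum, denominator and V-accumulator are rescaled whenever a larger maximum appears. A scalar reference sketch (illustrative only; the kernel itself processes whole tiles per warp, not one score at a time):

#include <math.h>

// s: incoming KQ scores, v: corresponding V values (one head-dimension element),
// m/sum/acc: running max, softmax denominator and V*softmax(KQ) accumulator,
// initialized by the caller to -INFINITY, 0 and 0.
void online_softmax_accumulate(const float * s, const float * v, int n,
                               float * m, float * sum, float * acc) {
    for (int k = 0; k < n; ++k) {
        const float m_new = fmaxf(*m, s[k]);
        const float scale = expf(*m - m_new);   // rescale everything accumulated so far
        const float p     = expf(s[k] - m_new); // softmax numerator of the new score
        *acc = *acc * scale + p * v[k];
        *sum = *sum * scale + p;
        *m   = m_new;
    }
    // The final division acc/sum happens either in the kernel (parallel_blocks == 1)
    // or in flash_attn_combine_results when the sequence was split across blocks.
}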
ggml-cuda/fattn-vec-f16.cuh ADDED
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml-cuda/fattn-vec-f32.cu ADDED
@@ -0,0 +1,384 @@
1
+ #include "common.cuh"
2
+ #include "fattn-common.cuh"
3
+ #include "fattn-vec-f32.cuh"
4
+
5
+ template<int D, int ncols, int parallel_blocks> // D == head size
6
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
7
+ __launch_bounds__(D, 1)
8
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
9
+ static __global__ void flash_attn_vec_ext_f32(
10
+ const char * __restrict__ Q,
11
+ const char * __restrict__ K,
12
+ const char * __restrict__ V,
13
+ const char * __restrict__ mask,
14
+ float * __restrict__ dst,
15
+ float2 * __restrict__ dst_meta,
16
+ const float scale,
17
+ const float max_bias,
18
+ const float m0,
19
+ const float m1,
20
+ const uint32_t n_head_log2,
21
+ const int ne00,
22
+ const int ne01,
23
+ const int ne02,
24
+ const int ne03,
25
+ const int ne10,
26
+ const int ne11,
27
+ const int ne12,
28
+ const int ne13,
29
+ const int ne31,
30
+ const int nb31,
31
+ const int nb01,
32
+ const int nb02,
33
+ const int nb03,
34
+ const int nb11,
35
+ const int nb12,
36
+ const int nb13,
37
+ const int ne0,
38
+ const int ne1,
39
+ const int ne2,
40
+ const int ne3) {
41
+ //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
42
+
43
+ const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
44
+ const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
45
+
46
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
47
+ const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
48
+ const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
49
+ const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
50
+ const half * maskh = (const half *) mask + ne11*ic0;
51
+
52
+ const int stride_KV = nb11 / sizeof(half);
53
+ const int stride_KV2 = nb11 / sizeof(half2);
54
+
55
+ float slope = 1.0f;
56
+
57
+ // ALiBi
58
+ if (max_bias > 0.0f) {
59
+ const int h = blockIdx.y;
60
+
61
+ const float base = h < n_head_log2 ? m0 : m1;
62
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
63
+
64
+ slope = powf(base, exph);
65
+ }
66
+
67
+ static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
68
+ constexpr int nwarps = D / WARP_SIZE;
69
+ const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
70
+ __builtin_assume(tid < D);
71
+
72
+ __shared__ float KQ[ncols*D];
73
+ #pragma unroll
74
+ for (int j = 0; j < ncols; ++j) {
75
+ KQ[j*D + tid] = -FLT_MAX/2.0f;
76
+ }
77
+
78
+ float kqmax[ncols];
79
+ #pragma unroll
80
+ for (int j = 0; j < ncols; ++j) {
81
+ kqmax[j] = -FLT_MAX/2.0f;
82
+ }
83
+ float kqsum[ncols] = {0.0f};
84
+
85
+ __shared__ float kqmax_shared[ncols][WARP_SIZE];
86
+ __shared__ float kqsum_shared[ncols][WARP_SIZE];
87
+ #pragma unroll
88
+ for (int j = 0; j < ncols; ++j) {
89
+ if (threadIdx.y == 0) {
90
+ kqmax_shared[j][threadIdx.x] = -FLT_MAX/2.0f;
91
+ kqsum_shared[j][threadIdx.x] = 0.0f;
92
+ }
93
+ }
94
+ __syncthreads();
95
+
96
+ // Scale Q and store it in registers as float2:
97
+ float2 Q_h2[ncols][D/(2*WARP_SIZE)];
98
+ #pragma unroll
99
+ for (int j = 0; j < ncols; ++j) {
100
+ #pragma unroll
101
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
102
+ const int i = i0 + threadIdx.x;
103
+
104
+ Q_h2[j][i0/WARP_SIZE] = Q_f2[j*(nb01/sizeof(float2)) + i];
105
+ Q_h2[j][i0/WARP_SIZE].x *= scale;
106
+ Q_h2[j][i0/WARP_SIZE].y *= scale;
107
+ }
108
+ }
109
+
110
+ float VKQ[ncols] = {0.0f};
111
+
112
+ const int k_start = parallel_blocks == 1 ? 0 : ip*D;
113
+ for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
114
+ // Calculate KQ tile and keep track of new maximum KQ values:
115
+
116
+ float kqmax_new_arr[ncols];
117
+ #pragma unroll
118
+ for (int j = 0; j < ncols; ++j) {
119
+ kqmax_new_arr[j] = kqmax[j];
120
+ }
121
+
122
+ #pragma unroll
123
+ for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
124
+ const int i_KQ = i_KQ_0 + threadIdx.y;
125
+
126
+ if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
127
+ break;
128
+ }
129
+
130
+ float sum[ncols] = {0.0f};
131
+ #pragma unroll
132
+ for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
133
+ const int k_KQ = k_KQ_0 + threadIdx.x;
134
+
135
+ const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
136
+ #pragma unroll
137
+ for (int j = 0; j < ncols; ++j) {
138
+ sum[j] += __low2float(K_ik) * Q_h2[j][k_KQ_0/WARP_SIZE].x;
139
+ sum[j] += __high2float(K_ik) * Q_h2[j][k_KQ_0/WARP_SIZE].y;
140
+ }
141
+ }
142
+
143
+ #pragma unroll
144
+ for (int j = 0; j < ncols; ++j) {
145
+ sum[j] = warp_reduce_sum(sum[j]);
146
+ sum[j] += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
147
+
148
+ kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum[j]);
149
+
150
+ if (threadIdx.x == 0) {
151
+ KQ[j*D + i_KQ] = sum[j];
152
+ }
153
+ }
154
+ }
155
+
156
+ #pragma unroll
157
+ for (int j = 0; j < ncols; ++j) {
158
+ float kqmax_new_j = kqmax_new_arr[j];
159
+
160
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
161
+ if (threadIdx.x == 0) {
162
+ kqmax_shared[j][threadIdx.y] = kqmax_new_j;
163
+ }
164
+ }
165
+
166
+ __syncthreads();
167
+
168
+ #pragma unroll
169
+ for (int j = 0; j < ncols; ++j) {
170
+ float kqmax_new_j = kqmax_shared[j][threadIdx.x];
171
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
172
+
173
+ const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
174
+ kqmax[j] = kqmax_new_j;
175
+
176
+ const float val = expf(KQ[j*D + tid] - kqmax[j]);
177
+ kqsum[j] = kqsum[j]*KQ_max_scale + val;
178
+ KQ[j*D + tid] = val;
179
+
180
+ VKQ[j] *= KQ_max_scale;
181
+ }
182
+
183
+ __syncthreads();
184
+
185
+ #pragma unroll
186
+ for (int k = 0; k < D; ++k) {
187
+ if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k >= ne11) {
188
+ break;
189
+ }
190
+
191
+ const float V_ki = __half2float(V_h[(k_VKQ_0 + k)*stride_KV + tid]);
192
+ #pragma unroll
193
+ for (int j = 0; j < ncols; ++j) {
194
+ VKQ[j] += V_ki*KQ[j*D + k];
195
+ }
196
+ }
197
+
198
+ __syncthreads();
199
+ }
200
+
201
+ #pragma unroll
202
+ for (int j = 0; j < ncols; ++j) {
203
+ kqsum[j] = warp_reduce_sum(kqsum[j]);
204
+ if (threadIdx.x == 0) {
205
+ kqsum_shared[j][threadIdx.y] = kqsum[j];
206
+ }
207
+ }
208
+
209
+ __syncthreads();
210
+
211
+ #pragma unroll
212
+ for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
213
+ kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
214
+ kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
215
+
216
+ float dst_val = VKQ[j_VKQ];
217
+ if (parallel_blocks == 1) {
218
+ dst_val /= kqsum[j_VKQ];
219
+ }
220
+ const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
221
+ dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
222
+ }
223
+
224
+ if (parallel_blocks != 1 && tid != 0) {
225
+ #pragma unroll
226
+ for (int j = 0; j < ncols; ++j) {
227
+ dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
228
+ }
229
+ }
230
+ }
231
+
232
+ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f32(
233
+ const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
234
+ ggml_cuda_pool & pool, cudaStream_t main_stream
235
+ ) {
236
+ ggml_cuda_pool_alloc<float> dst_tmp(pool);
237
+ ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
238
+
239
+ if (parallel_blocks > 1) {
240
+ dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
241
+ dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
242
+ }
243
+
244
+ constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
245
+ const dim3 block_dim(WARP_SIZE, nwarps, 1);
246
+ const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
247
+ const int shmem = 0;
248
+
249
+ float scale = 1.0f;
250
+ float max_bias = 0.0f;
251
+
252
+ memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
253
+ memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
254
+
255
+ const uint32_t n_head = Q->ne[2];
256
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
257
+
258
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
259
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
260
+
261
+ flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>
262
+ <<<blocks_num, block_dim, shmem, main_stream>>> (
263
+ (const char *) Q->data,
264
+ (const char *) K->data,
265
+ (const char *) V->data,
266
+ mask ? ((const char *) mask->data) : nullptr,
267
+ parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
268
+ scale, max_bias, m0, m1, n_head_log2,
269
+ Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
270
+ K->ne[0], K->ne[1], K->ne[2], K->ne[3],
271
+ mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
272
+ Q->nb[1], Q->nb[2], Q->nb[3],
273
+ K->nb[1], K->nb[2], K->nb[3],
274
+ KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
275
+ );
276
+ CUDA_CHECK(cudaGetLastError());
277
+
278
+ if (parallel_blocks == 1) {
279
+ return;
280
+ }
281
+
282
+ const dim3 block_dim_combine(D, 1, 1);
283
+ const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
284
+ const int shmem_combine = 0;
285
+
286
+ flash_attn_combine_results<D, parallel_blocks>
287
+ <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
288
+ (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
289
+ CUDA_CHECK(cudaGetLastError());
290
+ }
291
+
292
+ void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
293
+ const ggml_tensor * Q = dst->src[0];
294
+ const ggml_tensor * K = dst->src[1];
295
+ const ggml_tensor * V = dst->src[2];
296
+
297
+ const ggml_tensor * mask = dst->src[3];
298
+
299
+ ggml_tensor * KQV = dst;
300
+
301
+ GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
302
+
303
+ if (Q->ne[1] == 1) {
304
+ constexpr int cols_per_block = 1;
305
+ constexpr int parallel_blocks = 4;
306
+ switch (Q->ne[0]) {
307
+ case 64:
308
+ launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
309
+ break;
310
+ case 128:
311
+ launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
312
+ break;
313
+ default:
314
+ GGML_ASSERT(false);
315
+ break;
316
+ }
317
+ return;
318
+ }
319
+
320
+ if (Q->ne[1] == 2) {
321
+ constexpr int cols_per_block = 2;
322
+ constexpr int parallel_blocks = 4;
323
+ switch (Q->ne[0]) {
324
+ case 64:
325
+ launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
326
+ break;
327
+ case 128:
328
+ launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
329
+ break;
330
+ default:
331
+ GGML_ASSERT(false);
332
+ break;
333
+ }
334
+ return;
335
+ }
336
+
337
+ if (Q->ne[1] <= 4) {
338
+ constexpr int cols_per_block = 4;
339
+ constexpr int parallel_blocks = 4;
340
+ switch (Q->ne[0]) {
341
+ case 64:
342
+ launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
343
+ break;
344
+ case 128:
345
+ launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
346
+ break;
347
+ default:
348
+ GGML_ASSERT(false);
349
+ break;
350
+ }
351
+ return;
352
+ }
353
+
354
+ if (Q->ne[1] <= 8) {
355
+ constexpr int cols_per_block = 8;
356
+ constexpr int parallel_blocks = 4;
357
+ switch (Q->ne[0]) {
358
+ case 64:
359
+ launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
360
+ break;
361
+ case 128:
362
+ launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
363
+ break;
364
+ default:
365
+ GGML_ASSERT(false);
366
+ break;
367
+ }
368
+ return;
369
+ }
370
+
371
+ constexpr int cols_per_block = 8;
372
+ constexpr int parallel_blocks = 1;
373
+ switch (Q->ne[0]) {
374
+ case 64:
375
+ launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
376
+ break;
377
+ case 128:
378
+ launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
379
+ break;
380
+ default:
381
+ GGML_ASSERT(false);
382
+ break;
383
+ }
384
+ }
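
Both vector kernels also apply ALiBi when max_bias > 0: each head gets a slope derived from max_bias, m0, m1 and n_head_log2, which the launch functions above compute. A host-side sketch of that derivation (illustrative; it simply combines the launcher code and the in-kernel branch into one function):

#include <math.h>
#include <stdint.h>

float alibi_slope(int head, float max_bias, int n_head) {
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
    // Heads below the largest power of two use base m0 with exponent h+1,
    // the remaining heads use base m1 with odd exponents.
    if ((uint32_t) head < n_head_log2) {
        return powf(m0, head + 1);
    }
    return powf(m1, 2*(head - n_head_log2) + 1);
}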
ggml-cuda/fattn-vec-f32.cuh ADDED
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml-cuda/fattn.cu CHANGED
@@ -1,4 +1,7 @@
 #include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-vec-f16.cuh"
+#include "fattn-vec-f32.cuh"
 #include "fattn.cuh"
 
 #include <cstdint>
@@ -7,251 +10,6 @@
7
  #include <mma.h>
8
  #endif
9
 
10
- #define FATTN_KQ_STRIDE 256
11
- #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
12
- #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
13
-
14
- template<int D, int ncols, int parallel_blocks> // D == head size
15
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
16
- __launch_bounds__(D, 1)
17
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
18
- static __global__ void flash_attn_vec_ext_f16(
19
- const char * __restrict__ Q,
20
- const char * __restrict__ K,
21
- const char * __restrict__ V,
22
- const char * __restrict__ mask,
23
- float * __restrict__ dst,
24
- float2 * __restrict__ dst_meta,
25
- const float scale,
26
- const float max_bias,
27
- const float m0,
28
- const float m1,
29
- const uint32_t n_head_log2,
30
- const int ne00,
31
- const int ne01,
32
- const int ne02,
33
- const int ne03,
34
- const int ne10,
35
- const int ne11,
36
- const int ne12,
37
- const int ne13,
38
- const int ne31,
39
- const int nb31,
40
- const int nb01,
41
- const int nb02,
42
- const int nb03,
43
- const int nb11,
44
- const int nb12,
45
- const int nb13,
46
- const int ne0,
47
- const int ne1,
48
- const int ne2,
49
- const int ne3) {
50
- #if FP16_AVAILABLE
51
- //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
52
-
53
- const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
54
- const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
55
-
56
- const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
57
- const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
58
- const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
59
- const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
60
- const half * maskh = (const half *) mask + ne11*ic0;
61
-
62
- const int stride_KV = nb11 / sizeof(half);
63
- const int stride_KV2 = nb11 / sizeof(half2);
64
-
65
- half slopeh = __float2half(1.0f);
66
-
67
- // ALiBi
68
- if (max_bias > 0.0f) {
69
- const int h = blockIdx.y;
70
-
71
- const float base = h < n_head_log2 ? m0 : m1;
72
- const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
73
-
74
- slopeh = __float2half(powf(base, exph));
75
- }
76
-
77
- static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
78
- constexpr int nwarps = D / WARP_SIZE;
79
- const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
80
- __builtin_assume(tid < D);
81
-
82
- __shared__ half KQ[ncols*D];
83
- #pragma unroll
84
- for (int j = 0; j < ncols; ++j) {
85
- KQ[j*D + tid] = -HALF_MAX_HALF;
86
- }
87
- half2 * KQ2 = (half2 *) KQ;
88
-
89
- half kqmax[ncols];
90
- #pragma unroll
91
- for (int j = 0; j < ncols; ++j) {
92
- kqmax[j] = -HALF_MAX_HALF;
93
- }
94
- half kqsum[ncols] = {0.0f};
95
-
96
- __shared__ half kqmax_shared[ncols][WARP_SIZE];
97
- __shared__ half kqsum_shared[ncols][WARP_SIZE];
98
- #pragma unroll
99
- for (int j = 0; j < ncols; ++j) {
100
- if (threadIdx.y == 0) {
101
- kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF;
102
- kqsum_shared[j][threadIdx.x] = 0.0f;
103
- }
104
- }
105
- __syncthreads();
106
-
107
- // Convert Q to half2 and store in registers:
108
- half2 Q_h2[ncols][D/(2*WARP_SIZE)];
109
- #pragma unroll
110
- for (int j = 0; j < ncols; ++j) {
111
- #pragma unroll
112
- for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
113
- const int i = i0 + threadIdx.x;
114
-
115
- const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
116
- Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
117
- }
118
- }
119
-
120
- half2 VKQ[ncols] = {{0.0f, 0.0f}};
121
-
122
- const int k_start = parallel_blocks == 1 ? 0 : ip*D;
123
- for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
124
- // Calculate KQ tile and keep track of new maximum KQ values:
125
-
126
- // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
127
- // see https://github.com/ggerganov/llama.cpp/pull/7061 .
128
- // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
129
- half kqmax_new = kqmax[0];
130
- half kqmax_new_arr[ncols];
131
- #pragma unroll
132
- for (int j = 0; j < ncols; ++j) {
133
- kqmax_new_arr[j] = kqmax[j];
134
- }
135
-
136
- #pragma unroll
137
- for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
138
- const int i_KQ = i_KQ_0 + threadIdx.y;
139
-
140
- if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
141
- break;
142
- }
143
-
144
- half2 sum2[ncols] = {{0.0f, 0.0f}};
145
- #pragma unroll
146
- for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
147
- const int k_KQ = k_KQ_0 + threadIdx.x;
148
-
149
- const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
150
- #pragma unroll
151
- for (int j = 0; j < ncols; ++j) {
152
- sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE];
153
- }
154
- }
155
-
156
- #pragma unroll
157
- for (int j = 0; j < ncols; ++j) {
158
- sum2[j] = warp_reduce_sum(sum2[j]);
159
- half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
160
- sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
161
-
162
- if (ncols == 1) {
163
- kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
164
- } else {
165
- kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
166
- }
167
-
168
- if (threadIdx.x == 0) {
169
- KQ[j*D + i_KQ] = sum;
170
- }
171
- }
172
- }
173
-
174
- #pragma unroll
175
- for (int j = 0; j < ncols; ++j) {
176
- half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
177
-
178
- kqmax_new_j = warp_reduce_max(kqmax_new_j);
179
- if (threadIdx.x == 0) {
180
- kqmax_shared[j][threadIdx.y] = kqmax_new_j;
181
- }
182
- }
183
-
184
- __syncthreads();
185
-
186
- #pragma unroll
187
- for (int j = 0; j < ncols; ++j) {
188
- half kqmax_new_j = kqmax_shared[j][threadIdx.x];
189
- kqmax_new_j = warp_reduce_max(kqmax_new_j);
190
-
191
- const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
192
- kqmax[j] = kqmax_new_j;
193
-
194
- const half val = hexp(KQ[j*D + tid] - kqmax[j]);
195
- kqsum[j] = kqsum[j]*KQ_max_scale + val;
196
- KQ[j*D + tid] = val;
197
-
198
- VKQ[j] *= __half2half2(KQ_max_scale);
199
- }
200
-
201
- __syncthreads();
202
-
203
- #pragma unroll
204
- for (int k0 = 0; k0 < D; k0 += 2) {
205
- if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
206
- break;
207
- }
208
-
209
- half2 V_k;
210
- reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
211
- reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
212
- #pragma unroll
213
- for (int j = 0; j < ncols; ++j) {
214
- VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
215
- }
216
- }
217
-
218
- __syncthreads();
219
- }
220
-
221
- #pragma unroll
222
- for (int j = 0; j < ncols; ++j) {
223
- kqsum[j] = warp_reduce_sum(kqsum[j]);
224
- if (threadIdx.x == 0) {
225
- kqsum_shared[j][threadIdx.y] = kqsum[j];
226
- }
227
- }
228
-
229
- __syncthreads();
230
-
231
- #pragma unroll
232
- for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
233
- kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
234
- kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
235
-
236
- half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
237
- if (parallel_blocks == 1) {
238
- dst_val /= kqsum[j_VKQ];
239
- }
240
- const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
241
- dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
242
- }
243
-
244
- if (parallel_blocks != 1 && tid != 0) {
245
- #pragma unroll
246
- for (int j = 0; j < ncols; ++j) {
247
- dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
248
- }
249
- }
250
- #else
251
- NO_DEVICE_CODE;
252
- #endif // FP16_AVAILABLE
253
- }
254
-
255
  // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
256
  template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
257
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -655,54 +413,6 @@ static __global__ void flash_attn_ext_f16(
655
  #endif // FP16_MMA_AVAILABLE
656
  }
657
 
658
- template<int D, int parallel_blocks> // D == head size
659
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
660
- __launch_bounds__(D, 1)
661
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
662
- static __global__ void flash_attn_combine_results(
663
- const float * __restrict__ VKQ_parts,
664
- const float2 * __restrict__ VKQ_meta,
665
- float * __restrict__ dst) {
666
- #if FP16_AVAILABLE
667
- VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
668
- VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
669
- dst += D * gridDim.y*blockIdx.x;
670
-
671
- const int tid = threadIdx.x;
672
- __builtin_assume(tid < D);
673
-
674
- __shared__ float2 meta[parallel_blocks];
675
- if (tid < 2*parallel_blocks) {
676
- ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
677
- }
678
-
679
- __syncthreads();
680
-
681
- float kqmax = meta[0].x;
682
- #pragma unroll
683
- for (int l = 1; l < parallel_blocks; ++l) {
684
- kqmax = max(kqmax, meta[l].x);
685
- }
686
-
687
- float VKQ_numerator = 0.0f;
688
- float VKQ_denominator = 0.0f;
689
- #pragma unroll
690
- for (int l = 0; l < parallel_blocks; ++l) {
691
- const float diff = meta[l].x - kqmax;
692
- const float KQ_max_scale = expf(diff);
693
- const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
694
- *((uint32_t *) &KQ_max_scale) &= ftz_mask;
695
-
696
- VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
697
- VKQ_denominator += KQ_max_scale * meta[l].y;
698
- }
699
-
700
- dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
701
- #else
702
- NO_DEVICE_CODE;
703
- #endif // FP16_AVAILABLE
704
- }
705
-
706
  constexpr int get_max_power_of_2(int x) {
707
  return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
708
  }
@@ -727,66 +437,6 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
727
  static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
728
  static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
729
 
730
- template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f16(
731
- const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
732
- ggml_cuda_pool & pool, cudaStream_t main_stream
733
- ) {
734
- ggml_cuda_pool_alloc<float> dst_tmp(pool);
735
- ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
736
-
737
- if (parallel_blocks > 1) {
738
- dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
739
- dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
740
- }
741
-
742
- constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
743
- const dim3 block_dim(WARP_SIZE, nwarps, 1);
744
- const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
745
- const int shmem = 0;
746
-
747
- float scale = 1.0f;
748
- float max_bias = 0.0f;
749
-
750
- memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
751
- memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
752
-
753
- const uint32_t n_head = Q->ne[2];
754
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
755
-
756
- const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
757
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
758
-
759
- flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
760
- <<<blocks_num, block_dim, shmem, main_stream>>> (
761
- (const char *) Q->data,
762
- (const char *) K->data,
763
- (const char *) V->data,
764
- mask ? ((const char *) mask->data) : nullptr,
765
- parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
766
- scale, max_bias, m0, m1, n_head_log2,
767
- Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
768
- K->ne[0], K->ne[1], K->ne[2], K->ne[3],
769
- mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
770
- Q->nb[1], Q->nb[2], Q->nb[3],
771
- K->nb[1], K->nb[2], K->nb[3],
772
- KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
773
- );
774
- CUDA_CHECK(cudaGetLastError());
775
-
776
- if (parallel_blocks == 1) {
777
- return;
778
- }
779
-
780
- const dim3 block_dim_combine(D, 1, 1);
781
- const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
782
- const int shmem_combine = 0;
783
-
784
- flash_attn_combine_results<D, parallel_blocks>
785
- <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
786
- (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
787
- CUDA_CHECK(cudaGetLastError());
788
- }
789
-
790
  template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename KQ_acc_t> void launch_fattn_f16_impl(
791
  const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
792
  ggml_cuda_pool & pool, cudaStream_t main_stream
@@ -891,95 +541,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
891
 
892
  const int32_t precision = KQV->op_params[2];
893
 
894
- if (!fp16_mma_available(cc)) {
895
- GGML_ASSERT(precision == GGML_PREC_DEFAULT);
896
- GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
897
-
898
- if (Q->ne[1] == 1) {
899
- constexpr int cols_per_block = 1;
900
- constexpr int parallel_blocks = 4;
901
- switch (Q->ne[0]) {
902
- case 64:
903
- launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
904
- break;
905
- case 128:
906
- launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
907
- break;
908
- default:
909
- GGML_ASSERT(false);
910
- break;
911
- }
912
- return;
913
- }
914
-
915
- if (Q->ne[1] == 2) {
916
- constexpr int cols_per_block = 2;
917
- constexpr int parallel_blocks = 4;
918
- switch (Q->ne[0]) {
919
- case 64:
920
- launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
921
- break;
922
- case 128:
923
- launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
924
- break;
925
- default:
926
- GGML_ASSERT(false);
927
- break;
928
- }
929
- return;
930
- }
931
-
932
- if (Q->ne[1] <= 4) {
933
- constexpr int cols_per_block = 4;
934
- constexpr int parallel_blocks = 4;
935
- switch (Q->ne[0]) {
936
- case 64:
937
- launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
938
- break;
939
- case 128:
940
- launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
941
- break;
942
- default:
943
- GGML_ASSERT(false);
944
- break;
945
- }
946
- return;
947
- }
948
-
949
- if (Q->ne[1] <= 8) {
950
- constexpr int cols_per_block = 8;
951
- constexpr int parallel_blocks = 4;
952
- switch (Q->ne[0]) {
953
- case 64:
954
- launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
955
- break;
956
- case 128:
957
- launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
958
- break;
959
- default:
960
- GGML_ASSERT(false);
961
- break;
962
- }
963
- return;
964
- }
965
 
966
- constexpr int cols_per_block = 8;
967
- constexpr int parallel_blocks = 1;
968
- switch (Q->ne[0]) {
969
- case 64:
970
- launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
971
- break;
972
- case 128:
973
- launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
974
- break;
975
- default:
976
- GGML_ASSERT(false);
977
- break;
978
- }
979
  return;
980
  }
981
 
982
  if (precision != GGML_PREC_DEFAULT) {
 
 
 
 
 
983
  if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
984
  constexpr int cols_per_block = 16;
985
  constexpr int nwarps = 4;
@@ -1037,22 +614,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
1037
  }
1038
 
1039
  if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
1040
- constexpr int cols_per_block = 1;
1041
- constexpr int parallel_blocks = 4;
1042
- switch (Q->ne[0]) {
1043
- case 64:
1044
- launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
1045
- break;
1046
- case 128:
1047
- launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
1048
- break;
1049
- case 256:
1050
- launch_fattn_vec_f16<256, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
1051
- break;
1052
- default:
1053
- GGML_ASSERT(false);
1054
- break;
1055
- }
1056
  return;
1057
  }
1058
 
 
 #include <mma.h>
 #endif
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 
 #endif // FP16_MMA_AVAILABLE
 }
 
 constexpr int get_max_power_of_2(int x) {
     return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
 }
 
 static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
 
 template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename KQ_acc_t> void launch_fattn_f16_impl(
     const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
     ggml_cuda_pool & pool, cudaStream_t main_stream
 
     const int32_t precision = KQV->op_params[2];
 
+    if (!fast_fp16_available(cc)) {
+        ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+        return;
+    }
 
+    if (!fp16_mma_available(cc)) {
+        ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
         return;
     }
 
     if (precision != GGML_PREC_DEFAULT) {
+        if (Q->ne[1] == 1 && (Q->ne[0] == 64 || Q->ne[0] == 128)) {
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+            return;
+        }
+
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
             constexpr int cols_per_block = 16;
             constexpr int nwarps         =  4;
 
     }
 
     if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
+        ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         return;
     }
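
Taken together, the fattn.cu changes reduce ggml_cuda_flash_attn_ext to a small dispatcher over the new entry points. A condensed sketch of the selection order (illustrative and simplified, not the literal function; cc is read the same way the supports_op change above reads it):

// Kernel selection after this commit, in priority order.
void flash_attn_ext_dispatch_sketch(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * Q = dst->src[0];
    const int cc = ggml_cuda_info().devices[ctx.device].cc;
    const int32_t precision = dst->op_params[2];

    if (!fast_fp16_available(cc)) {              // e.g. cc 610 Pascal: FP16 math is slow
        ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        return;
    }
    if (!fp16_mma_available(cc)) {               // pre-Volta NVIDIA / AMD: no tensor cores
        ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
        return;
    }
    if (precision != GGML_PREC_DEFAULT && Q->ne[1] == 1 && (Q->ne[0] == 64 || Q->ne[0] == 128)) {
        ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); // FP32 accumulation requested, single-token decode
        return;
    }
    if (precision == GGML_PREC_DEFAULT && Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
        ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); // fast single-token path when tensor cores exist
        return;
    }
    // Everything else falls through to the tiled tensor-core FlashAttention kernels,
    // with FP16 or FP32 accumulators depending on the requested precision.
}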